Skip to content

Commit

Permalink
update plot_network to use _format functions instead of _prepare func…
Browse files Browse the repository at this point in the history
…tion
  • Loading branch information
kbonney committed Dec 3, 2024
1 parent 4a8f5a4 commit 046006a
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 104 deletions.
161 changes: 64 additions & 97 deletions wntr/graphics/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,33 +43,6 @@ def _get_angle(line, loc=0.5):
angle = math.degrees(angle)
return angle


def _prepare_attribute(attribute, gdf):
kwds = {}
if attribute is not None:
# if dict convert to a series
if isinstance(attribute, dict):
attribute = pd.Series(attribute)
# if series add as a column to link gdf
if isinstance(attribute, pd.Series):
gdf["_attribute"] = attribute
kwds["column"] = "_attribute"
# if list, create new boolean column that captures which indices are in the list
elif isinstance(attribute, list):
gdf["_attribute"] = np.nan
gdf.loc[gdf.index.isin(attribute), "_attribute"] = 1
kwds["column"] = "_attribute"
# if str, assert that column name exists
elif isinstance(attribute, str):
if attribute not in gdf.columns:
raise KeyError(f"attribute {attribute} does not exist.")
kwds["column"] = attribute
else:
raise TypeError("attribute must be dict, Series, list, or str")
else:
kwds["color"] = "black"
return kwds

def _format_node_attribute(node_attribute, wn):

if isinstance(node_attribute, str):
Expand Down Expand Up @@ -213,8 +186,7 @@ def plot_network(
if node_colorbar_label is None and isinstance(node_attribute, str):
node_colorbar_label = node_attribute
if link_colorbar_label is None and isinstance(link_attribute, str):
link_colorbar_label = link_attribute

link_colorbar_label = link_attribute

wn_gis = wn.to_gis()
# add node_type so that node assets can be plotted separately
Expand All @@ -224,37 +196,71 @@ def plot_network(
link_gdf = pd.concat((wn_gis.pipes, wn_gis.pumps, wn_gis.valves))
node_gdf = pd.concat((wn_gis.junctions, wn_gis.tanks, wn_gis.reservoirs))

# process link attribute
# link_kwds = {}
link_kwds = _prepare_attribute(link_attribute, link_gdf)
# Node attribute
# if node_attribute is not None:
# if isinstance(node_attribute, list):
# node_cmap = 'Reds'
# add_colorbar = False
# node_attribute = _format_node_attribute(node_attribute, wn)
# else:
# add_colorbar = False
# link_

# handle cbar/cmap
if isinstance(link_attribute, list):
link_kwds["cmap"] = custom_colormap(2,["red", "red"])
link_cbar = False
elif isinstance(link_attribute, (dict, pd.Series, str)):
link_kwds["cmap"] = link_cmap
node_kwds = {}
node_cbar = add_colorbar
if node_attribute is not None:
node_gdf["_attribute"] = _format_node_attribute(node_attribute, wn)
node_kwds["column"] = "_attribute"

link_attribute_values = link_gdf[link_kwds["column"]]
if link_range[0] is None:
link_kwds["vmin"] = np.nanmin(link_attribute_values)
# handle cbar/cmap
if isinstance(node_attribute, list):
node_kwds["cmap"] = custom_colormap(2,["red", "red"])
node_cbar = False
elif isinstance(node_attribute, (dict, pd.Series, str)):
node_kwds["cmap"] = node_cmap

# manually extract min/max if no range is given
node_attribute_values = node_gdf[node_kwds["column"]]
if node_range[0] is None:
node_kwds["vmin"] = np.nanmin(node_attribute_values)
else:
node_kwds["vmin"] = node_range[0]
if node_range[1] is None:
node_kwds["vmax"] = np.nanmax(node_attribute_values)
else:
node_kwds["vmax"] = node_range[1]
else:
link_kwds["vmin"] = link_range[0]
if link_range[1] is None:
link_kwds["vmax"] = np.nanmax(link_attribute_values)
raise TypeError("attribute must be dict, Series, list, or str")
else:
node_kwds["color"] = "black"
node_cbar = False

node_kwds["alpha"] = node_alpha
node_kwds["markersize"] = node_size

node_cbar_kwds = {}
node_cbar_kwds["shrink"] = 0.5
node_cbar_kwds["pad"] = 0.0
node_cbar_kwds["alpha"] = node_alpha
node_cbar_kwds["label"] = node_colorbar_label

# Link attribute
link_kwds = {}
link_cbar = add_colorbar
if link_attribute is not None:
link_gdf["_attribute"] = pd.Series(_format_link_attribute(link_attribute, wn))
link_kwds["column"] = "_attribute"

# handle cbar/cmap
if isinstance(link_attribute, list):
link_kwds["cmap"] = custom_colormap(2,["red", "red"])
link_cbar = False
elif isinstance(link_attribute, (dict, pd.Series, str)):
link_kwds["cmap"] = link_cmap

# manually extract min/max if no range is given
link_attribute_values = link_gdf[link_kwds["column"]]
if link_range[0] is None:
link_kwds["vmin"] = np.nanmin(link_attribute_values)
else:
link_kwds["vmin"] = link_range[0]
if link_range[1] is None:
link_kwds["vmax"] = np.nanmax(link_attribute_values)
else:
link_kwds["vmax"] = link_range[1]
else:
link_kwds["vmax"] = link_range[1]

link_cbar = add_colorbar
raise TypeError("attribute must be dict, Series, list, or str")
else:
link_kwds["color"] = "black"
link_cbar = False
Expand All @@ -264,7 +270,7 @@ def plot_network(

background_link_kwds = {}
background_link_kwds["color"] = "grey"
background_link_kwds["linewidth"] = link_width
background_link_kwds["linewidth"] = link_width / 2
background_link_kwds["alpha"] = link_alpha

link_cbar_kwds = {}
Expand All @@ -273,48 +279,12 @@ def plot_network(
link_cbar_kwds["label"] = link_colorbar_label
link_cbar_kwds["alpha"] = link_alpha

# process node attribute
node_kwds = _prepare_attribute(node_attribute, node_gdf)

if isinstance(node_attribute, list):
node_kwds["cmap"] = custom_colormap(2,["red", "red"])
node_cbar = False
elif isinstance(node_attribute, (dict, pd.Series, str)):
node_kwds["cmap"] = node_cmap

node_attribute_values = node_gdf[node_kwds["column"]]
if node_range[0] is None:
node_kwds["vmin"] = np.nanmin(node_attribute_values)
else:
node_kwds["vmin"] = node_range[0]
if node_range[1] is None:
node_kwds["vmax"] = np.nanmax(node_attribute_values)
else:
node_kwds["vmax"] = node_range[1]

node_cbar = add_colorbar
else:
node_kwds["color"] = "black"
node_cbar = False

node_kwds["alpha"] = node_alpha
node_kwds["markersize"] = node_size

node_cbar_kwds = {}
node_cbar_kwds["shrink"] = 0.5
node_cbar_kwds["pad"] = 0.0
node_cbar_kwds["alpha"] = node_alpha
node_cbar_kwds["label"] = node_colorbar_label

# # prepare legend item list
# legend_items = []
missing_node_kwds={"color": "black"}
missing_link_kwds={"color": "black", "linewidth": link_width / 2}
missing_link_kwds={"color": "black"}

# plot nodes - each type is plotted separately to allow for different marker types
node_gdf[node_gdf.node_type == "Junction"].plot(
ax=ax, aspect=aspect, zorder=3, legend=False, label="Junction", missing_kwds=missing_node_kwds, **node_kwds)
# legend_items.append(plt.Line2D([0], [0], marker='o', color='w', label='Junctions', markerfacecolor='blue', markersize=6))

node_kwds["markersize"] = node_size * 2.0
node_gdf[node_gdf.node_type == "Tank"].plot(
Expand Down Expand Up @@ -362,9 +332,6 @@ def plot_network(
angle = row["_angle"]
ax.scatter(x,y, color="black", s=50, marker=(3,0, angle-90))

# NOTE: The coloring on the symbols will change based on the colors of the underlying object.
# If this isn't desired behavior, handles and labels can be build manually using:
# handle = plt.Line2D([0], [0], marker='o', color='w', label='Junctions', markerfacecolor='black', markersize=6)
if legend:
handles, labels = ax.get_legend_handles_labels()
leg = ax.legend(handles, labels, loc='upper right', title="Legend")
Expand Down
11 changes: 4 additions & 7 deletions wntr/tests/test_graphics.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,6 @@ def test_plot_network_options(self):

cmap = matplotlib.colormaps['viridis']


inp_file = join(ex_datadir, "Net6.inp")
wn = wntr.network.WaterNetworkModel(inp_file)

Expand All @@ -163,14 +162,12 @@ def test_plot_network_options(self):
"node_range": [0,20],
"node_alpha": 0.5,
"node_colorbar_label": "test_label"},
{"node_attribute": "elevation",
"node_range": [0,1],
"node_alpha": 0.5,
"node_colorbar_label": "test_label"},
{"link_attribute": "diameter",
"link_range": [0,1],
"link_range": [0,None],
"link_alpha": 0.5,
"link_colorbar_label": "test_label"},
{"link_attribute": "diameter",
"node_attribute": "elevation"},
{"node_labels": True,
"link_labels": True},
{"node_attribute": "elevation",
Expand Down Expand Up @@ -198,7 +195,7 @@ def test_plot_network_options(self):
if compare:
fig, ax = plt.subplots(1,2)
wntr.graphics.plot_network(wn, ax=ax[0], title="GIS plot_network", **kwargs)
wntr.graphics.plot_network_nx(wn, ax=ax[1], title="NX plot_network", **kwargs)
plot_network_nx(wn, ax=ax[1], title="NX plot_network", **kwargs)
fig.savefig(filename, format="png")
plt.close(fig)
else:
Expand Down

0 comments on commit 046006a

Please sign in to comment.