From 046006a2b848ea169fd0621415f7a0255c7a3ffd Mon Sep 17 00:00:00 2001 From: kbonney Date: Tue, 3 Dec 2024 10:20:45 -0500 Subject: [PATCH] update plot_network to use _format functions instead of _prepare function --- wntr/graphics/network.py | 161 ++++++++++++++---------------------- wntr/tests/test_graphics.py | 11 +-- 2 files changed, 68 insertions(+), 104 deletions(-) diff --git a/wntr/graphics/network.py b/wntr/graphics/network.py index 2433b0003..7b49704fd 100644 --- a/wntr/graphics/network.py +++ b/wntr/graphics/network.py @@ -43,33 +43,6 @@ def _get_angle(line, loc=0.5): angle = math.degrees(angle) return angle - -def _prepare_attribute(attribute, gdf): - kwds = {} - if attribute is not None: - # if dict convert to a series - if isinstance(attribute, dict): - attribute = pd.Series(attribute) - # if series add as a column to link gdf - if isinstance(attribute, pd.Series): - gdf["_attribute"] = attribute - kwds["column"] = "_attribute" - # if list, create new boolean column that captures which indices are in the list - elif isinstance(attribute, list): - gdf["_attribute"] = np.nan - gdf.loc[gdf.index.isin(attribute), "_attribute"] = 1 - kwds["column"] = "_attribute" - # if str, assert that column name exists - elif isinstance(attribute, str): - if attribute not in gdf.columns: - raise KeyError(f"attribute {attribute} does not exist.") - kwds["column"] = attribute - else: - raise TypeError("attribute must be dict, Series, list, or str") - else: - kwds["color"] = "black" - return kwds - def _format_node_attribute(node_attribute, wn): if isinstance(node_attribute, str): @@ -213,8 +186,7 @@ def plot_network( if node_colorbar_label is None and isinstance(node_attribute, str): node_colorbar_label = node_attribute if link_colorbar_label is None and isinstance(link_attribute, str): - link_colorbar_label = link_attribute - + link_colorbar_label = link_attribute wn_gis = wn.to_gis() # add node_type so that node assets can be plotted separately @@ -224,37 +196,71 @@ def plot_network( link_gdf = pd.concat((wn_gis.pipes, wn_gis.pumps, wn_gis.valves)) node_gdf = pd.concat((wn_gis.junctions, wn_gis.tanks, wn_gis.reservoirs)) - # process link attribute - # link_kwds = {} - link_kwds = _prepare_attribute(link_attribute, link_gdf) # Node attribute - # if node_attribute is not None: - # if isinstance(node_attribute, list): - # node_cmap = 'Reds' - # add_colorbar = False - # node_attribute = _format_node_attribute(node_attribute, wn) - # else: - # add_colorbar = False - # link_ - - # handle cbar/cmap - if isinstance(link_attribute, list): - link_kwds["cmap"] = custom_colormap(2,["red", "red"]) - link_cbar = False - elif isinstance(link_attribute, (dict, pd.Series, str)): - link_kwds["cmap"] = link_cmap + node_kwds = {} + node_cbar = add_colorbar + if node_attribute is not None: + node_gdf["_attribute"] = _format_node_attribute(node_attribute, wn) + node_kwds["column"] = "_attribute" - link_attribute_values = link_gdf[link_kwds["column"]] - if link_range[0] is None: - link_kwds["vmin"] = np.nanmin(link_attribute_values) + # handle cbar/cmap + if isinstance(node_attribute, list): + node_kwds["cmap"] = custom_colormap(2,["red", "red"]) + node_cbar = False + elif isinstance(node_attribute, (dict, pd.Series, str)): + node_kwds["cmap"] = node_cmap + + # manually extract min/max if no range is given + node_attribute_values = node_gdf[node_kwds["column"]] + if node_range[0] is None: + node_kwds["vmin"] = np.nanmin(node_attribute_values) + else: + node_kwds["vmin"] = node_range[0] + if node_range[1] is None: + node_kwds["vmax"] = np.nanmax(node_attribute_values) + else: + node_kwds["vmax"] = node_range[1] else: - link_kwds["vmin"] = link_range[0] - if link_range[1] is None: - link_kwds["vmax"] = np.nanmax(link_attribute_values) + raise TypeError("attribute must be dict, Series, list, or str") + else: + node_kwds["color"] = "black" + node_cbar = False + + node_kwds["alpha"] = node_alpha + node_kwds["markersize"] = node_size + + node_cbar_kwds = {} + node_cbar_kwds["shrink"] = 0.5 + node_cbar_kwds["pad"] = 0.0 + node_cbar_kwds["alpha"] = node_alpha + node_cbar_kwds["label"] = node_colorbar_label + + # Link attribute + link_kwds = {} + link_cbar = add_colorbar + if link_attribute is not None: + link_gdf["_attribute"] = pd.Series(_format_link_attribute(link_attribute, wn)) + link_kwds["column"] = "_attribute" + + # handle cbar/cmap + if isinstance(link_attribute, list): + link_kwds["cmap"] = custom_colormap(2,["red", "red"]) + link_cbar = False + elif isinstance(link_attribute, (dict, pd.Series, str)): + link_kwds["cmap"] = link_cmap + + # manually extract min/max if no range is given + link_attribute_values = link_gdf[link_kwds["column"]] + if link_range[0] is None: + link_kwds["vmin"] = np.nanmin(link_attribute_values) + else: + link_kwds["vmin"] = link_range[0] + if link_range[1] is None: + link_kwds["vmax"] = np.nanmax(link_attribute_values) + else: + link_kwds["vmax"] = link_range[1] else: - link_kwds["vmax"] = link_range[1] - - link_cbar = add_colorbar + raise TypeError("attribute must be dict, Series, list, or str") else: link_kwds["color"] = "black" link_cbar = False @@ -264,7 +270,7 @@ def plot_network( background_link_kwds = {} background_link_kwds["color"] = "grey" - background_link_kwds["linewidth"] = link_width + background_link_kwds["linewidth"] = link_width / 2 background_link_kwds["alpha"] = link_alpha link_cbar_kwds = {} @@ -273,48 +279,12 @@ def plot_network( link_cbar_kwds["label"] = link_colorbar_label link_cbar_kwds["alpha"] = link_alpha - # process node attribute - node_kwds = _prepare_attribute(node_attribute, node_gdf) - - if isinstance(node_attribute, list): - node_kwds["cmap"] = custom_colormap(2,["red", "red"]) - node_cbar = False - elif isinstance(node_attribute, (dict, pd.Series, str)): - node_kwds["cmap"] = node_cmap - - node_attribute_values = node_gdf[node_kwds["column"]] - if node_range[0] is None: - node_kwds["vmin"] = np.nanmin(node_attribute_values) - else: - node_kwds["vmin"] = node_range[0] - if node_range[1] is None: - node_kwds["vmax"] = np.nanmax(node_attribute_values) - else: - node_kwds["vmax"] = node_range[1] - - node_cbar = add_colorbar - else: - node_kwds["color"] = "black" - node_cbar = False - - node_kwds["alpha"] = node_alpha - node_kwds["markersize"] = node_size - - node_cbar_kwds = {} - node_cbar_kwds["shrink"] = 0.5 - node_cbar_kwds["pad"] = 0.0 - node_cbar_kwds["alpha"] = node_alpha - node_cbar_kwds["label"] = node_colorbar_label - - # # prepare legend item list - # legend_items = [] missing_node_kwds={"color": "black"} - missing_link_kwds={"color": "black", "linewidth": link_width / 2} + missing_link_kwds={"color": "black"} # plot nodes - each type is plotted separately to allow for different marker types node_gdf[node_gdf.node_type == "Junction"].plot( ax=ax, aspect=aspect, zorder=3, legend=False, label="Junction", missing_kwds=missing_node_kwds, **node_kwds) - # legend_items.append(plt.Line2D([0], [0], marker='o', color='w', label='Junctions', markerfacecolor='blue', markersize=6)) node_kwds["markersize"] = node_size * 2.0 node_gdf[node_gdf.node_type == "Tank"].plot( @@ -362,9 +332,6 @@ def plot_network( angle = row["_angle"] ax.scatter(x,y, color="black", s=50, marker=(3,0, angle-90)) - # NOTE: The coloring on the symbols will change based on the colors of the underlying object. - # If this isn't desired behavior, handles and labels can be build manually using: - # handle = plt.Line2D([0], [0], marker='o', color='w', label='Junctions', markerfacecolor='black', markersize=6) if legend: handles, labels = ax.get_legend_handles_labels() leg = ax.legend(handles, labels, loc='upper right', title="Legend") diff --git a/wntr/tests/test_graphics.py b/wntr/tests/test_graphics.py index f5c5c6c7c..8caed7e45 100644 --- a/wntr/tests/test_graphics.py +++ b/wntr/tests/test_graphics.py @@ -143,7 +143,6 @@ def test_plot_network_options(self): cmap = matplotlib.colormaps['viridis'] - inp_file = join(ex_datadir, "Net6.inp") wn = wntr.network.WaterNetworkModel(inp_file) @@ -163,14 +162,12 @@ def test_plot_network_options(self): "node_range": [0,20], "node_alpha": 0.5, "node_colorbar_label": "test_label"}, - {"node_attribute": "elevation", - "node_range": [0,1], - "node_alpha": 0.5, - "node_colorbar_label": "test_label"}, {"link_attribute": "diameter", - "link_range": [0,1], + "link_range": [0,None], "link_alpha": 0.5, "link_colorbar_label": "test_label"}, + {"link_attribute": "diameter", + "node_attribute": "elevation"}, {"node_labels": True, "link_labels": True}, {"node_attribute": "elevation", @@ -198,7 +195,7 @@ def test_plot_network_options(self): if compare: fig, ax = plt.subplots(1,2) wntr.graphics.plot_network(wn, ax=ax[0], title="GIS plot_network", **kwargs) - wntr.graphics.plot_network_nx(wn, ax=ax[1], title="NX plot_network", **kwargs) + plot_network_nx(wn, ax=ax[1], title="NX plot_network", **kwargs) fig.savefig(filename, format="png") plt.close(fig) else: