diff --git a/wntr/graphics/network.py b/wntr/graphics/network.py index 8ef9b396c..a974b7c2f 100644 --- a/wntr/graphics/network.py +++ b/wntr/graphics/network.py @@ -9,6 +9,7 @@ import matplotlib.pyplot as plt import matplotlib.path as mpath from matplotlib import animation +import matplotlib as mpl import numpy as np try: @@ -218,14 +219,15 @@ def plot_network( if isinstance(link_attribute, list): link_kwds["column"] = "_attribute" link_kwds["cmap"] = custom_colormap(2,("red", "red")) - link_kwds["legend"] = False + link_cbar = False elif isinstance(link_attribute, (dict, pd.Series, str)): link_kwds["cmap"] = link_cmap link_kwds["vmin"] = link_range[0] link_kwds["vmax"] = link_range[1] - link_kwds["legend"] = add_colorbar + link_cbar = add_colorbar else: link_kwds["color"] = "black" + link_cbar = False link_kwds["linewidth"] = link_width link_kwds["alpha"] = link_alpha @@ -245,15 +247,16 @@ def plot_network( if isinstance(node_attribute, list): node_kwds["cmap"] = custom_colormap(2,("red", "red")) - node_kwds["legend"] = False + node_cbar = False elif isinstance(node_attribute, (dict, pd.Series, str)): node_kwds["cmap"] = node_cmap node_kwds["vmin"] = node_range[0] node_kwds["vmax"] = node_range[1] - node_kwds["legend"] = add_colorbar + node_cbar = add_colorbar else: node_kwds["color"] = "black" - + node_cbar = False + node_kwds["alpha"] = node_alpha node_kwds["markersize"] = node_size @@ -264,25 +267,37 @@ def plot_network( # plot nodes - each type is plotted separately to allow for different marker types node_gdf[node_gdf.node_type == "Junction"].plot( - ax=ax, aspect=aspect, zorder=3, legend_kwds=node_cbar_kwds, **node_kwds) + ax=ax, aspect=aspect, zorder=3, legend=False, **node_kwds) - # turn off legend for subsequent node plots - node_kwds["legend"] = False + if node_cbar: + norm = plt.Normalize(vmin=node_kwds["vmin"], vmax=node_kwds["vmax"],) + sm = mpl.cm.ScalarMappable(cmap=node_kwds["cmap"], norm=norm) + sm.set_array([]) + + node_cbar = ax.figure.colorbar(sm, ax=ax, **node_cbar_kwds) node_kwds["markersize"] = node_size * 2.0 node_gdf[node_gdf.node_type == "Tank"].plot( - ax=ax, aspect=aspect, zorder=4, marker=tank_marker, **node_kwds) + ax=ax, aspect=aspect, zorder=4, marker=tank_marker, legend=False, **node_kwds) node_kwds["markersize"] = node_size * 3.0 node_gdf[node_gdf.node_type == "Reservoir"].plot( - ax=ax, aspect=aspect, zorder=5, marker=reservoir_marker, **node_kwds) + ax=ax, aspect=aspect, zorder=5, marker=reservoir_marker, legend=False, **node_kwds) # plot links link_gdf.plot( - ax=ax, aspect=aspect, zorder=1, **background_link_kwds) + ax=ax, aspect=aspect, zorder=1, legend=False, **background_link_kwds) link_gdf.plot( - ax=ax, aspect=aspect, zorder=2, legend_kwds=link_cbar_kwds, **link_kwds) + ax=ax, aspect=aspect, zorder=2, legend=False, **link_kwds) + + # Create a ScalarMappable for the colorbar + if link_cbar: + norm = plt.Normalize(vmin=link_kwds["vmin"], vmax=link_kwds["vmax"]) + sm = mpl.cm.ScalarMappable(cmap=link_kwds["cmap"], norm=norm) + sm.set_array([]) # Needed to create an empty array for the colorbar + + ax.figure.colorbar(sm, ax=ax, **link_cbar_kwds) # Adjusts size and position of colorbar if len(wn_gis.pumps) >0: # wn_gis.pumps.plot(ax=ax, color="purple", aspect=aspect)