Skip to content

Commit

Permalink
add mpl_args and stroke_width
Browse files Browse the repository at this point in the history
  • Loading branch information
kgoebber committed Sep 8, 2023
1 parent c123c31 commit d6e0808
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 70 deletions.
120 changes: 81 additions & 39 deletions src/metpy/plots/declarative.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,14 @@ def lookup_map_feature(feature_name):
return feat.with_scale(scaler)


def plot_kwargs(data):
def plot_kwargs(data, args):
"""Set the keyword arguments for MapPanel plotting."""
if hasattr(data.metpy, 'cartopy_crs'):
# Conditionally add cartopy transform if we are on a map.
kwargs = {'transform': data.metpy.cartopy_crs}
else:
kwargs = {}
kwargs.update(args)
return kwargs


Expand Down Expand Up @@ -102,8 +103,19 @@ def __dir__(self):
lambda name: not (name in dir(HasTraits) or name.startswith('_')),
dir(type(self))
)

mpl_args = Union([Dict(), Int(), Float(), Unicode()])

mpl_args = Dict(allow_none=True)
mpl_args.__doc__ = """Supply a dictionary of valid Matplotlib keyword arguments to modify
how the plot variable is drawn.
Using this attribute you must choose the appropriate keyword arguments (kwargs) based on
what you are plotting (e.g., contours, color-filled contours, image plot, etc.). This is
available for all plot types (ContourPlot, FilledContourPlot, RasterPlot, ImagePlot,
BarbPlot, ArrowPlot, PlotGeometry, and PlotObs). For PlotObs, the kwargs re those to
specify the StationPlot object. NOTE: Setting the mpl_args trait will override
any other trait that corresponds to a specific kwarg for the particular plot type
(e.g., linecolor, linewidth).
"""


class Panel(MetPyHasTraits):
Expand Down Expand Up @@ -938,20 +950,17 @@ def _build(self):
"""Build the plot by calling any plotting methods as necessary."""
x_like, y_like, imdata = self.plotdata

kwargs = plot_kwargs(imdata)
kwargs = plot_kwargs(imdata, self.mpl_args)

# If we're on a map, we use min/max for y and manually figure out origin to try to
# avoid upside down images created by images where y[0] > y[-1], as well as
# specifying the transform
kwargs['extent'] = (x_like[0], x_like[-1], y_like.min(), y_like.max())
kwargs['origin'] = 'upper' if y_like[0] > y_like[-1] else 'lower'
kwargs.setdefault('cmap', self._cmap_obj)
kwargs.setdefault('norm', self._norm_obj)

self.handle = self.parent.ax.imshow(
imdata,
cmap=self._cmap_obj,
norm=self._norm_obj,
**kwargs
)
self.handle = self.parent.ax.imshow(imdata, **kwargs)


@exporter.export
Expand Down Expand Up @@ -995,11 +1004,12 @@ def _build(self):
"""Build the plot by calling any plotting methods as necessary."""
x_like, y_like, imdata = self.plotdata

kwargs = plot_kwargs(imdata)
kwargs = plot_kwargs(imdata, self.mpl_args)
kwargs.setdefault('linewidths', self.linewidth)
kwargs.setdefault('colors', self.linecolor)
kwargs.setdefault('linestyles', self.linestyle)

self.handle = self.parent.ax.contour(x_like, y_like, imdata, self.contours,
colors=self.linecolor, linewidths=self.linewidth,
linestyles=self.linestyle, **kwargs)
self.handle = self.parent.ax.contour(x_like, y_like, imdata, self.contours, **kwargs)
if self.clabels:
self.handle.clabel(inline=1, fmt='%.0f', inline_spacing=8,
use_clabeltext=True, fontsize=self.label_fontsize)
Expand All @@ -1020,11 +1030,11 @@ def _build(self):
"""Build the plot by calling any plotting methods as necessary."""
x_like, y_like, imdata = self.plotdata

kwargs = plot_kwargs(imdata)
kwargs = plot_kwargs(imdata, self.mpl_args)
kwargs.setdefault('cmap', self._cmap_obj)
kwargs.setdefault('norm', self._norm_obj)

self.handle = self.parent.ax.contourf(x_like, y_like, imdata, self.contours,
cmap=self._cmap_obj, norm=self._norm_obj,
**kwargs)
self.handle = self.parent.ax.contourf(x_like, y_like, imdata, self.contours, **kwargs)


@exporter.export
Expand All @@ -1042,11 +1052,11 @@ def _build(self):
"""Build the raster plot by calling any plotting methods as necessary."""
x_like, y_like, imdata = self.plotdata

kwargs = plot_kwargs(imdata)
kwargs = plot_kwargs(imdata, self.mpl_args)
kwargs.setdefault('cmap', self._cmap_obj)
kwargs.setdefault('norm', self._norm_obj)

self.handle = self.parent.ax.pcolormesh(x_like, y_like, imdata,
cmap=self._cmap_obj, norm=self._norm_obj,
**kwargs)
self.handle = self.parent.ax.pcolormesh(x_like, y_like, imdata, **kwargs)


@exporter.export
Expand Down Expand Up @@ -1221,7 +1231,11 @@ def _build(self):
"""Build the plot by calling needed plotting methods as necessary."""
x_like, y_like, u, v = self.plotdata

kwargs = plot_kwargs(u)
kwargs = plot_kwargs(u, self.mpl_args)
kwargs.setdefault('color', self.color)
kwargs.setdefault('pivot', self.pivot)
kwargs.setdefault('length', self.barblength)
kwargs.setdefault('zorder', 2)

# Conditionally apply the proper transform
if 'transform' in kwargs and self.earth_relative:
Expand All @@ -1232,7 +1246,7 @@ def _build(self):
self.handle = self.parent.ax.barbs(
x_like[wind_slice], y_like[wind_slice],
u.values[wind_slice], v.values[wind_slice],
color=self.color, pivot=self.pivot, length=self.barblength, zorder=2, **kwargs)
**kwargs)


@exporter.export
Expand Down Expand Up @@ -1283,7 +1297,10 @@ def _build(self):
"""Build the plot by calling needed plotting methods as necessary."""
x_like, y_like, u, v = self.plotdata

kwargs = plot_kwargs(u)
kwargs = plot_kwargs(u, self.mpl_args)
kwargs.setdefault('color', self.color)
kwargs.setdefault('pivot', self.pivot)
kwargs.setdefault('scale', self.arrowscale)

# Conditionally apply the proper transform
if 'transform' in kwargs and self.earth_relative:
Expand All @@ -1294,7 +1311,7 @@ def _build(self):
self.handle = self.parent.ax.quiver(
x_like[wind_slice], y_like[wind_slice],
u.values[wind_slice], v.values[wind_slice],
color=self.color, pivot=self.pivot, scale=self.arrowscale, **kwargs)
**kwargs)

# The order here needs to match the order of the tuple
if self.arrowkey is not None:
Expand Down Expand Up @@ -1569,9 +1586,12 @@ def _build(self):
scale = 1. if self.parent._proj_obj == ccrs.PlateCarree() else 100000.
point_locs = self.parent._proj_obj.transform_points(ccrs.PlateCarree(), lon, lat)
subset = reduce_point_density(point_locs, self.reduce_points * scale)
kwargs = self.mpl_args
kwargs.setdefault('clip_on', True)
kwargs.setdefault('transform', ccrs.PlateCarree())
kwargs.setdefault('fontsize', self.fontsize)

self.handle = StationPlot(self.parent.ax, lon[subset], lat[subset], clip_on=True,
transform=ccrs.PlateCarree(), fontsize=self.fontsize)
self.handle = StationPlot(self.parent.ax, lon[subset], lat[subset], **kwargs)

for i, ob_type in enumerate(self.fields):
field_kwargs = {}
Expand Down Expand Up @@ -1669,6 +1689,17 @@ class PlotGeometry(MetPyHasTraits):
the sequence of colors as needed. Default value is black.
"""

stroke_width = Union([Instance(collections.abc.Iterable), Float()], default_value=[1],
allow_none=True)
stroke_width.__doc__ = """Stroke width(s) for polygons and lines.
A single integer or floating point value or collection of values representing the size of
the stroke width. If a collection, the first value corresponds to the first Shapely
object in `geometry`, the second value corresponds to the second Shapely object, and so on.
If `stroke_width` is shorter than `geometry`, `stroke_width` cycles back to the beginning,
repeating the sequence of values as needed. Default value is 1.
"""

marker = Unicode(default_value='.', allow_none=False)
marker.__doc__ = """Symbol used to denote points.
Expand Down Expand Up @@ -1847,27 +1878,38 @@ def _build(self):
else self.label_edgecolor)
self.label_facecolor = (['none'] if self.label_facecolor is None
else self.label_facecolor)
kwargs = self.mpl_args

# Each Shapely object is plotted separately with its corresponding colors and label
for geo_obj, stroke, fill, label, fontcolor, fontoutline in zip(
self.geometry, cycle(self.stroke), cycle(self.fill), cycle(self.labels),
cycle(self.label_facecolor), cycle(self.label_edgecolor)):
for geo_obj, stroke, strokewidth, fill, label, fontcolor, fontoutline in zip(
self.geometry, cycle(self.stroke), cycle(self.stroke_width), cycle(self.fill),
cycle(self.labels), cycle(self.label_facecolor), cycle(self.label_edgecolor)):
# Plot the Shapely object with the appropriate method and colors
if isinstance(geo_obj, (MultiPolygon, Polygon)):
self.parent.ax.add_geometries([geo_obj], edgecolor=stroke,
facecolor=fill, crs=ccrs.PlateCarree())
kwargs.setdefault('edgecolor', stroke)
kwargs.setdefault('linewidths', strokewidth)
kwargs.setdefault('facecolor', fill)
kwargs.setdefault('crs', ccrs.PlateCarree())
self.parent.ax.add_geometries([geo_obj], **kwargs)
elif isinstance(geo_obj, (MultiLineString, LineString)):
self.parent.ax.add_geometries([geo_obj], edgecolor=stroke,
facecolor='none', crs=ccrs.PlateCarree())
kwargs.setdefault('edgecolor', stroke)
kwargs.setdefault('linewidths', strokewidth)
kwargs.setdefault('facecolor', 'none')
kwargs.setdefault('crs', ccrs.PlateCarree())
self.parent.ax.add_geometries([geo_obj], **kwargs)
elif isinstance(geo_obj, MultiPoint):
kwargs.setdefault('color', fill)
kwargs.setdefault('marker', self.marker)
kwargs.setdefault('transform', ccrs.PlateCarree())
for point in geo_obj.geoms:
lon, lat = point.coords[0]
self.parent.ax.plot(lon, lat, color=fill, marker=self.marker,
transform=ccrs.PlateCarree())
self.parent.ax.plot(lon, lat, **kwargs)
elif isinstance(geo_obj, Point):
kwargs.setdefault('color', fill)
kwargs.setdefault('marker', self.marker)
kwargs.setdefault('transform', ccrs.PlateCarree())
lon, lat = geo_obj.coords[0]
self.parent.ax.plot(lon, lat, color=fill, marker=self.marker,
transform=ccrs.PlateCarree())
self.parent.ax.plot(lon, lat, **kwargs)

# Plot labels if provided
if label:
Expand Down
Binary file added tests/plots/baseline/test_colorfill_args.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
99 changes: 68 additions & 31 deletions tests/plots/test_declarative.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_declarative_image():

img = ImagePlot()
img.data = data.metpy.parse_cf('IR')
img.colormap = 'Greys_r'
img.mpl_args = {'cmap': 'Greys_r'}

panel = MapPanel()
panel.title = 'Test'
Expand Down Expand Up @@ -376,7 +376,7 @@ def test_declarative_layers_plot_options():
contour.level = 700 * units.hPa
contour.contours = 5
contour.linewidth = 1
contour.linecolor = 'grey'
contour.mpl_args = {'colors': 'grey'}

panel = MapPanel()
panel.area = 'us'
Expand Down Expand Up @@ -615,33 +615,6 @@ def test_colorfill():
return pc.figure


@pytest.mark.mpl_image_compare(remove_text=True, tolerance=0.009)
@needs_cartopy
def test_colorfill_args():
"""Test that we can use ContourFillPlot."""
data = xr.open_dataset(get_test_data('narr_example.nc', as_file_obj=False))

contour = FilledContourPlot()
contour.data = data
contour.level = 700 * units.hPa
contour.field = 'Temperature'
contour.colormap = 'coolwarm'
contour.colorbar = 'vertical'
contour.mpl_args = {'alpha': 0.6}

panel = MapPanel()
panel.area = (-110, -60, 25, 55)
panel.layers = []
panel.plots = [contour]

pc = PanelContainer()
pc.panel = panel
pc.size = (12, 8)
pc.draw()

return pc.figure


@pytest.mark.mpl_image_compare(remove_text=True, tolerance=0.02)
def test_colorfill_with_image_range(cfeature):
"""Test that we can use ContourFillPlot with image_range bounds."""
Expand Down Expand Up @@ -812,7 +785,8 @@ def test_declarative_barb_options():
barb.field = ['u_wind', 'v_wind']
barb.skip = (10, 10)
barb.color = 'blue'
barb.pivot = 'tip'
barb.pivot = 'middle'
barb.mpl_args = {'pivot': 'tip'}
barb.barblength = 6.5

panel = MapPanel()
Expand Down Expand Up @@ -841,7 +815,8 @@ def test_declarative_arrowplot():
arrows.field = ['u_wind', 'v_wind']
arrows.skip = (10, 10)
arrows.color = 'blue'
arrows.pivot = 'mid'
arrows.pivot = 'tip'
arrows.mpl_args = {'pivot': 'mid'}
arrows.arrowscale = 1000

panel = MapPanel()
Expand Down Expand Up @@ -1314,6 +1289,39 @@ def test_declarative_sfc_obs(ccrs):
return pc.figure


@pytest.mark.mpl_image_compare(remove_text=True, tolerance=0.025)
def test_declarative_sfc_obs_args(ccrs):
"""Test making a surface observation plot with mpl arguments."""
data = pd.read_csv(get_test_data('SFC_obs.csv', as_file_obj=False),
infer_datetime_format=True, parse_dates=['valid'])

obs = PlotObs()
obs.data = data
obs.time = datetime(1993, 3, 12, 12)
obs.time_window = timedelta(minutes=15)
obs.level = None
obs.fields = ['tmpf']
obs.colors = ['black']
obs.mpl_args = {'fontsize': 12}

# Panel for plot with Map features
panel = MapPanel()
panel.layout = (1, 1, 1)
panel.projection = ccrs.PlateCarree()
panel.area = 'in'
panel.layers = ['states']
panel.plots = [obs]

# Bringing it all together
pc = PanelContainer()
pc.size = (10, 10)
pc.panels = [panel]

pc.draw()

return pc.figure


@pytest.mark.mpl_image_compare(remove_text=True, tolerance=0.016)
@needs_cartopy
def test_declarative_sfc_text():
Expand Down Expand Up @@ -1815,6 +1823,33 @@ def test_declarative_raster():
return pc.figure


@pytest.mark.mpl_image_compare(remove_text=True, tolerance=0.02)
@needs_cartopy
def test_declarative_raster_options():
"""Test making a raster plot."""
data = xr.open_dataset(get_test_data('narr_example.nc', as_file_obj=False))

Check warning

Code scanning / CodeQL

File is not always closed Warning test

File is opened but is not closed.

raster = RasterPlot()
raster.data = data
raster.colormap = 'viridis'
raster.field = 'Temperature'
raster.level = 700 * units.hPa
raster.mpl_args = {'alpha': 1, 'cmap': 'coolwarm'}

panel = MapPanel()
panel.area = 'us'
panel.projection = 'lcc'
panel.layers = ['coastline']
panel.plots = [raster]

pc = PanelContainer()
pc.size = (8.0, 8)
pc.panels = [panel]
pc.draw()

return pc.figure


@pytest.mark.mpl_image_compare(remove_text=True, tolerance=0.607)
@needs_cartopy
def test_declarative_region_modifier_zoom_in():
Expand Down Expand Up @@ -1975,6 +2010,7 @@ def test_declarative_plot_geometry_polygons():
geo = PlotGeometry()
geo.geometry = [slgt_risk_polygon, enh_risk_polygon]
geo.stroke = ['#DDAA00', '#FF6600']
geo.stroke_width = [1]
geo.fill = None
geo.labels = ['SLGT', 'ENH']
geo.label_facecolor = ['#FFE066', '#FFA366']
Expand Down Expand Up @@ -2019,6 +2055,7 @@ def test_declarative_plot_geometry_lines(ccrs):
geo.stroke = 'green'
geo.labels = ['Irma', '+/- 0.25 deg latitude']
geo.label_facecolor = None
geo.mpl_args = {'linewidth': 1}

# Place plot in a panel and container
panel = MapPanel()
Expand Down

0 comments on commit d6e0808

Please sign in to comment.