Skip to content

Commit

Permalink
Merge pull request #146 from xylar/reuse-patches-in-planar-horiz
Browse files Browse the repository at this point in the history
Allow patches to be reused in `plot_horiz_field()`
  • Loading branch information
xylar authored Nov 2, 2023
2 parents 06cbb87 + a44f53b commit 0527fc5
Show file tree
Hide file tree
Showing 7 changed files with 142 additions and 56 deletions.
33 changes: 32 additions & 1 deletion docs/developers_guide/framework/visualization.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,42 @@ polygons characterized by the field values, accordingly.
An example function call that uses the default vertical level (top) is:

```python
cell_mask = ds_init.maxLevelCell >= 1
plot_horiz_field(config, ds, ds_mesh, 'normalVelocity',
'final_normalVelocity.png',
t_index=t_index,
vmin=-max_velocity, vmax=max_velocity,
cmap='cmo.balance', show_patch_edges=True)
cmap='cmo.balance', show_patch_edges=True,
cell_mask=cell_mask)
```

The `cell_mask` argument can be any field indicating which horizontal cells
are valid and which are not. A typical value for ocean plots is as shown
above: whether there are any active cells in the water column.

For increased efficiency, you can store the `patches` and `patch_mask` from
one call to `plot_horiz_field()` and reuse them in subsequent calls. The
`patches` and `patch_mask` are specific to the dimension (`nCell` or `nEdges`)
of the field to plot and the `cell_mask`. So separate `patches` and
`patch_mask` variables should be stored for as needed:

```python
cell_mask = ds_init.maxLevelCell >= 1
cell_patches, cell_patch_mask = plot_horiz_field(
ds=ds, ds_mesh=ds_mesh, field_name='ssh', out_file_name='plots/ssh.png',
vmin=-720, vmax=0, figsize=figsize, cell_mask=cell_mask)

plot_horiz_field(ds=ds, ds_mesh=ds_mesh, field_name='bottomDepth',
out_file_name='plots/bottomDepth.png', vmin=0, vmax=720,
figsize=figsize, patches=cell_patches,
patch_mask=cell_patch_mask)

edge_patches, edge_patch_mask = plot_horiz_field(
ds=ds, ds_mesh=ds_mesh, field_name='normalVelocity',
out_file_name='plots/normalVelocity.png', t_index=t_index,
vmin=-0.1, vmax=0.1, cmap='cmo.balance', cell_mask=cell_mask)

...
```

(dev-visualization-global)=
Expand Down
5 changes: 3 additions & 2 deletions polaris/ocean/tasks/baroclinic_channel/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,9 @@ def run(self):

write_netcdf(ds, 'initial_state.nc')

cell_mask = ds.maxLevelCell >= 1
plot_horiz_field(ds, ds_mesh, 'temperature',
'initial_temperature.png')
'initial_temperature.png', cell_mask=cell_mask)
plot_horiz_field(ds, ds_mesh, 'normalVelocity',
'initial_normal_velocity.png', cmap='cmo.balance',
show_patch_edges=True)
show_patch_edges=True, cell_mask=cell_mask)
4 changes: 2 additions & 2 deletions polaris/ocean/tasks/baroclinic_channel/rpe/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,14 @@ def run(self):
ax = axes[row_index]
ds = xr.open_dataset(f'output_nu_{nu:g}.nc', decode_times=False)
ds = ds.isel(nVertLevels=0)
ds['maxLevelCell'] = ds_init.maxLevelCell
times = ds.daysSinceStartOfSim.values
time_index = np.argmin(np.abs(times - time))

cell_mask = ds_init.maxLevelCell >= 1
plot_horiz_field(ds, ds_mesh, 'temperature', ax=ax,
cmap='cmo.thermal', t_index=time_index,
vmin=min_temp, vmax=max_temp,
cmap_title='SST (C)')
cmap_title='SST (C)', cell_mask=cell_mask)
ax.set_title(f'day {times[time_index]:g}, $\\nu_h=${nu:g}')

plt.savefig(output_filename)
8 changes: 5 additions & 3 deletions polaris/ocean/tasks/baroclinic_channel/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,15 @@ def run(self):
ds_mesh = xr.load_dataset('mesh.nc')
ds_init = xr.load_dataset('init.nc')
ds = xr.load_dataset('output.nc')
ds['maxLevelCell'] = ds_init.maxLevelCell
t_index = ds.sizes['Time'] - 1
cell_mask = ds_init.maxLevelCell >= 1
plot_horiz_field(ds, ds_mesh, 'temperature',
'final_temperature.png', t_index=t_index)
'final_temperature.png', t_index=t_index,
cell_mask=cell_mask)
max_velocity = np.max(np.abs(ds.normalVelocity.values))
plot_horiz_field(ds, ds_mesh, 'normalVelocity',
'final_normalVelocity.png',
t_index=t_index,
vmin=-max_velocity, vmax=max_velocity,
cmap='cmo.balance', show_patch_edges=True)
cmap='cmo.balance', show_patch_edges=True,
cell_mask=cell_mask)
15 changes: 9 additions & 6 deletions polaris/ocean/tasks/inertial_gravity_wave/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ def run(self):
ds_mesh = xr.open_dataset(f'mesh_{mesh_name}.nc')
ds_init = xr.open_dataset(f'init_{mesh_name}.nc')
ds = xr.open_dataset(f'output_{mesh_name}.nc')
ds['maxLevelCell'] = ds_init.maxLevelCell
exact = ExactSolution(ds_init, config)

t0 = datetime.datetime.strptime(ds.xtime.values[0].decode(),
Expand All @@ -93,16 +92,20 @@ def run(self):
if error_range is None:
error_range = np.max(np.abs(ds.ssh_error.values))

plot_horiz_field(ds, ds_mesh, 'ssh', ax=axes[i, 0],
cmap='cmo.balance', t_index=ds.sizes["Time"] - 1,
vmin=-eta0, vmax=eta0, cmap_title="SSH (m)")
cell_mask = ds_init.maxLevelCell >= 1
patches, patch_mask = plot_horiz_field(
ds, ds_mesh, 'ssh', ax=axes[i, 0], cmap='cmo.balance',
t_index=ds.sizes["Time"] - 1, vmin=-eta0, vmax=eta0,
cmap_title="SSH (m)", cell_mask=cell_mask)
plot_horiz_field(ds, ds_mesh, 'ssh_exact', ax=axes[i, 1],
cmap='cmo.balance',
vmin=-eta0, vmax=eta0, cmap_title="SSH (m)")
vmin=-eta0, vmax=eta0, cmap_title="SSH (m)",
patches=patches, patch_mask=patch_mask)
plot_horiz_field(ds, ds_mesh, 'ssh_error', ax=axes[i, 2],
cmap='cmo.balance',
cmap_title=r"$\Delta$ SSH (m)",
vmin=-error_range, vmax=error_range)
vmin=-error_range, vmax=error_range,
patches=patches, patch_mask=patch_mask)

axes[0, 0].set_title('Numerical solution')
axes[0, 1].set_title('Analytical solution')
Expand Down
15 changes: 9 additions & 6 deletions polaris/ocean/tasks/manufactured_solution/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ def run(self):
ds_mesh = xr.open_dataset(f'mesh_{mesh_name}.nc')
ds_init = xr.open_dataset(f'init_{mesh_name}.nc')
ds = xr.open_dataset(f'output_{mesh_name}.nc')
ds['maxLevelCell'] = ds_init.maxLevelCell
exact = ExactSolution(config, ds_init)

t0 = datetime.datetime.strptime(ds.xtime.values[0].decode(),
Expand All @@ -93,15 +92,19 @@ def run(self):
if error_range is None:
error_range = np.max(np.abs(ds.ssh_error.values))

plot_horiz_field(ds, ds_mesh, 'ssh', ax=axes[i, 0],
cmap='cmo.balance', t_index=ds.sizes["Time"] - 1,
vmin=-eta0, vmax=eta0, cmap_title="SSH")
cell_mask = ds_init.maxLevelCell >= 1
patches, patch_mask = plot_horiz_field(
ds, ds_mesh, 'ssh', ax=axes[i, 0], cmap='cmo.balance',
t_index=ds.sizes["Time"] - 1, vmin=-eta0, vmax=eta0,
cmap_title="SSH", cell_mask=cell_mask)
plot_horiz_field(ds, ds_mesh, 'ssh_exact', ax=axes[i, 1],
cmap='cmo.balance',
vmin=-eta0, vmax=eta0, cmap_title="SSH")
vmin=-eta0, vmax=eta0, cmap_title="SSH",
patches=patches, patch_mask=patch_mask)
plot_horiz_field(ds, ds_mesh, 'ssh_error', ax=axes[i, 2],
cmap='cmo.balance', cmap_title="dSSH",
vmin=-error_range, vmax=error_range)
vmin=-error_range, vmax=error_range,
patches=patches, patch_mask=patch_mask)

axes[0, 0].set_title('Numerical solution')
axes[0, 1].set_title('Analytical solution')
Expand Down
118 changes: 82 additions & 36 deletions polaris/viz/planar.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ def plot_horiz_field(ds, ds_mesh, field_name, out_file_name=None, # noqa: C901
ax=None, title=None, t_index=None, z_index=None,
vmin=None, vmax=None, show_patch_edges=False,
cmap=None, cmap_set_under=None, cmap_set_over=None,
cmap_scale='linear', cmap_title=None, figsize=None):
cmap_scale='linear', cmap_title=None, figsize=None,
vert_dim='nVertLevels', cell_mask=None, patches=None,
patch_mask=None):
"""
Plot a horizontal field from a planar domain using x,y coordinates at a
single time and depth slice.
Expand All @@ -27,13 +29,16 @@ def plot_horiz_field(ds, ds_mesh, field_name, out_file_name=None, # noqa: C901
ds_mesh : xarray.Dataset
A data set containing horizontal mesh variables
field_name: str
field_name : str
The name of the variable to plot, which must be present in ds
out_file_name: str
out_file_name : str, optional
The path to which the plot image should be written
title: str, optional
ax : matplotlib.axes.Axes
Axes to plot to if making a multi-panel figure
title : str, optional
The title of the plot
vmin : float, optional
Expand Down Expand Up @@ -65,12 +70,36 @@ def plot_horiz_field(ds, ds_mesh, field_name, out_file_name=None, # noqa: C901
cmap_scale : {'log', 'linear'}, optional
Whether the colormap is logarithmic or linear
cmap_title : str
cmap_title : str, optional
Title for color bar
figsize : tuple
figsize : tuple, optional
The width and height of the figure in inches. Default is determined
based on the aspect ratio of the domain.
vert_dim : str, optional
Name of the vertical dimension
cell_mask : numpy.ndarray, optional
A ``bool`` mask indicating where cells are valid, used to mask fields
on both cells and edges. Not used if ``patches`` and ``patch_mask``
are supplied
patches : list of numpy.ndarray, optional
Patches from a previous call to ``plot_horiz_field()``
patch_mask : numpy.ndarray, optional
A mask of where the field has patches from a previous call to
``plot_horiz_field()``
Returns
-------
patches : list of numpy.ndarray
Patches to reuse for future plots. Patches for cells can only be
reused for other plots on cells and similarly for edges.
patch_mask : numpy.ndarray
A mask used to select entries in the field that have patches
"""
use_mplstyle()

Expand All @@ -89,9 +118,6 @@ def plot_horiz_field(ds, ds_mesh, field_name, out_file_name=None, # noqa: C901
if title is None:
title = field_name

if 'maxLevelCell' not in ds:
raise ValueError(
'maxLevelCell must be added to ds before plotting.')
if field_name not in ds:
raise ValueError(
f'{field_name} must be present in ds before plotting.')
Expand All @@ -102,37 +128,48 @@ def plot_horiz_field(ds, ds_mesh, field_name, out_file_name=None, # noqa: C901
t_index = 0
if t_index is not None:
field = field.isel(Time=t_index)
if 'nVertLevels' in field.dims and z_index is None:
if vert_dim in field.dims and z_index is None:
z_index = 0
if z_index is not None:
field = field.isel(nVertLevels=z_index)

if 'nCells' in field.dims:
ocean_mask = ds.maxLevelCell - 1 >= 0
ocean_patches, ocean_mask = _compute_cell_patches(ds_mesh, ocean_mask)
elif 'nEdges' in field.dims:
ocean_mask = np.ones_like(field, dtype='bool')
ocean_mask = _remove_boundary_edges_from_mask(ds_mesh, ocean_mask)
ocean_patches, ocean_mask = _compute_edge_patches(ds_mesh, ocean_mask)
ocean_patches.set_array(field[ocean_mask])
field = field.isel({vert_dim: z_index})

if patches is not None:
if patch_mask is None:
raise ValueError('You must supply both patches and patch_mask '
'from a previous call to plot_horiz_field()')
else:
if cell_mask is None:
cell_mask = np.ones_like(field, type='bool')
if 'nCells' in field.dims:
patch_mask = cell_mask
patches, patch_mask = _compute_cell_patches(ds_mesh, patch_mask)
elif 'nEdges' in field.dims:
patch_mask = _edge_mask_from_cell_mask(ds_mesh, cell_mask)
patch_mask = _remove_boundary_edges_from_mask(ds_mesh, patch_mask)
patches, patch_mask = _compute_edge_patches(ds_mesh, patch_mask)
else:
raise ValueError('Cannot plot a field without dim nCells or '
'nEdges')
local_patches = PatchCollection(patches, alpha=1.)
local_patches.set_array(field[patch_mask])
if cmap is not None:
ocean_patches.set_cmap(cmap)
local_patches.set_cmap(cmap)
if cmap_set_under is not None:
current_cmap = ocean_patches.get_cmap()
current_cmap = local_patches.get_cmap()
current_cmap.set_under(cmap_set_under)
if cmap_set_over is not None:
current_cmap = ocean_patches.get_cmap()
current_cmap = local_patches.get_cmap()
current_cmap.set_over(cmap_set_over)

if show_patch_edges:
ocean_patches.set_edgecolor('black')
local_patches.set_edgecolor('black')
else:
ocean_patches.set_edgecolor('face')
ocean_patches.set_clim(vmin=vmin, vmax=vmax)
local_patches.set_edgecolor('face')
local_patches.set_clim(vmin=vmin, vmax=vmax)

if cmap_scale == 'log':
ocean_patches.set_norm(LogNorm(vmin=max(1e-10, vmin),
vmax=vmax, clip=False))
local_patches.set_norm(LogNorm(vmin=max(1e-10, vmin),
vmax=vmax, clip=False))

if figsize is None:
width = ds_mesh.xCell.max() - ds_mesh.xCell.min()
Expand All @@ -145,19 +182,32 @@ def plot_horiz_field(ds, ds_mesh, field_name, out_file_name=None, # noqa: C901
if create_fig:
plt.figure(figsize=figsize)
ax = plt.subplot(111)
ax.add_collection(ocean_patches)
ax.add_collection(local_patches)
ax.set_xlabel('x (km)')
ax.set_ylabel('y (km)')
ax.set_aspect('equal')
ax.autoscale(tight=True)
cbar = plt.colorbar(ocean_patches, extend='both', shrink=0.7, ax=ax)
cbar = plt.colorbar(local_patches, extend='both', shrink=0.7, ax=ax)
if cmap_title is not None:
cbar.set_label(cmap_title)
if create_fig:
plt.title(title)
plt.savefig(out_file_name, bbox_inches='tight', pad_inches=0.2)
plt.close()

return patches, patch_mask


def _edge_mask_from_cell_mask(ds, cell_mask):
cells_on_edge = ds.cellsOnEdge - 1
valid = cells_on_edge >= 0
# the edge mask is True if either adjacent cell is valid and its mask is
# True
edge_mask = np.logical_or(
np.logical_and(valid[:, 0], cell_mask[cells_on_edge[:, 0]]),
np.logical_and(valid[:, 1], cell_mask[cells_on_edge[:, 1]]))
return edge_mask


def _remove_boundary_edges_from_mask(ds, mask):
area_cell = ds.areaCell.values
Expand Down Expand Up @@ -221,9 +271,7 @@ def _compute_cell_patches(ds, mask):
polygon = Polygon(vertices, closed=True)
patches.append(polygon)

p = PatchCollection(patches, alpha=1.)

return p, mask
return patches, mask


def _compute_edge_patches(ds, mask):
Expand Down Expand Up @@ -252,6 +300,4 @@ def _compute_edge_patches(ds, mask):
polygon = Polygon(vertices, closed=True)
patches.append(polygon)

p = PatchCollection(patches, alpha=1.)

return p, mask
return patches, mask

0 comments on commit 0527fc5

Please sign in to comment.