diff --git a/docs/developers_guide/framework/visualization.md b/docs/developers_guide/framework/visualization.md index 61a413440..7c5c4292d 100644 --- a/docs/developers_guide/framework/visualization.md +++ b/docs/developers_guide/framework/visualization.md @@ -46,11 +46,42 @@ polygons characterized by the field values, accordingly. An example function call that uses the default vertical level (top) is: ```python +cell_mask = ds_init.maxLevelCell >= 1 plot_horiz_field(config, ds, ds_mesh, 'normalVelocity', 'final_normalVelocity.png', t_index=t_index, vmin=-max_velocity, vmax=max_velocity, - cmap='cmo.balance', show_patch_edges=True) + cmap='cmo.balance', show_patch_edges=True, + cell_mask=cell_mask) +``` + +The `cell_mask` argument can be any field indicating which horizontal cells +are valid and which are not. A typical value for ocean plots is as shown +above: whether there are any active cells in the water column. + +For increased efficiency, you can store the `patches` and `patch_mask` from +one call to `plot_horiz_field()` and reuse them in subsequent calls. The +`patches` and `patch_mask` are specific to the dimension (`nCell` or `nEdges`) +of the field to plot and the `cell_mask`. So separate `patches` and +`patch_mask` variables should be stored for as needed: + +```python +cell_mask = ds_init.maxLevelCell >= 1 +cell_patches, cell_patch_mask = plot_horiz_field( + ds=ds, ds_mesh=ds_mesh, field_name='ssh', out_file_name='plots/ssh.png', + vmin=-720, vmax=0, figsize=figsize, cell_mask=cell_mask) + +plot_horiz_field(ds=ds, ds_mesh=ds_mesh, field_name='bottomDepth', + out_file_name='plots/bottomDepth.png', vmin=0, vmax=720, + figsize=figsize, patches=cell_patches, + patch_mask=cell_patch_mask) + +edge_patches, edge_patch_mask = plot_horiz_field( + ds=ds, ds_mesh=ds_mesh, field_name='normalVelocity', + out_file_name='plots/normalVelocity.png', t_index=t_index, + vmin=-0.1, vmax=0.1, cmap='cmo.balance', cell_mask=cell_mask) + +... ``` (dev-visualization-global)= diff --git a/polaris/ocean/tasks/baroclinic_channel/init.py b/polaris/ocean/tasks/baroclinic_channel/init.py index 4f5290d82..983efb7ec 100644 --- a/polaris/ocean/tasks/baroclinic_channel/init.py +++ b/polaris/ocean/tasks/baroclinic_channel/init.py @@ -161,8 +161,9 @@ def run(self): write_netcdf(ds, 'initial_state.nc') + cell_mask = ds.maxLevelCell >= 1 plot_horiz_field(ds, ds_mesh, 'temperature', - 'initial_temperature.png') + 'initial_temperature.png', cell_mask=cell_mask) plot_horiz_field(ds, ds_mesh, 'normalVelocity', 'initial_normal_velocity.png', cmap='cmo.balance', - show_patch_edges=True) + show_patch_edges=True, cell_mask=cell_mask) diff --git a/polaris/ocean/tasks/baroclinic_channel/rpe/analysis.py b/polaris/ocean/tasks/baroclinic_channel/rpe/analysis.py index 7d56b3765..3d8c3ee98 100644 --- a/polaris/ocean/tasks/baroclinic_channel/rpe/analysis.py +++ b/polaris/ocean/tasks/baroclinic_channel/rpe/analysis.py @@ -100,14 +100,14 @@ def run(self): ax = axes[row_index] ds = xr.open_dataset(f'output_nu_{nu:g}.nc', decode_times=False) ds = ds.isel(nVertLevels=0) - ds['maxLevelCell'] = ds_init.maxLevelCell times = ds.daysSinceStartOfSim.values time_index = np.argmin(np.abs(times - time)) + cell_mask = ds_init.maxLevelCell >= 1 plot_horiz_field(ds, ds_mesh, 'temperature', ax=ax, cmap='cmo.thermal', t_index=time_index, vmin=min_temp, vmax=max_temp, - cmap_title='SST (C)') + cmap_title='SST (C)', cell_mask=cell_mask) ax.set_title(f'day {times[time_index]:g}, $\\nu_h=${nu:g}') plt.savefig(output_filename) diff --git a/polaris/ocean/tasks/baroclinic_channel/viz.py b/polaris/ocean/tasks/baroclinic_channel/viz.py index 4dc344619..aa2416c92 100644 --- a/polaris/ocean/tasks/baroclinic_channel/viz.py +++ b/polaris/ocean/tasks/baroclinic_channel/viz.py @@ -40,13 +40,15 @@ def run(self): ds_mesh = xr.load_dataset('mesh.nc') ds_init = xr.load_dataset('init.nc') ds = xr.load_dataset('output.nc') - ds['maxLevelCell'] = ds_init.maxLevelCell t_index = ds.sizes['Time'] - 1 + cell_mask = ds_init.maxLevelCell >= 1 plot_horiz_field(ds, ds_mesh, 'temperature', - 'final_temperature.png', t_index=t_index) + 'final_temperature.png', t_index=t_index, + cell_mask=cell_mask) max_velocity = np.max(np.abs(ds.normalVelocity.values)) plot_horiz_field(ds, ds_mesh, 'normalVelocity', 'final_normalVelocity.png', t_index=t_index, vmin=-max_velocity, vmax=max_velocity, - cmap='cmo.balance', show_patch_edges=True) + cmap='cmo.balance', show_patch_edges=True, + cell_mask=cell_mask) diff --git a/polaris/ocean/tasks/inertial_gravity_wave/viz.py b/polaris/ocean/tasks/inertial_gravity_wave/viz.py index 703b2bdc2..a146e1658 100644 --- a/polaris/ocean/tasks/inertial_gravity_wave/viz.py +++ b/polaris/ocean/tasks/inertial_gravity_wave/viz.py @@ -76,7 +76,6 @@ def run(self): ds_mesh = xr.open_dataset(f'mesh_{mesh_name}.nc') ds_init = xr.open_dataset(f'init_{mesh_name}.nc') ds = xr.open_dataset(f'output_{mesh_name}.nc') - ds['maxLevelCell'] = ds_init.maxLevelCell exact = ExactSolution(ds_init, config) t0 = datetime.datetime.strptime(ds.xtime.values[0].decode(), @@ -93,16 +92,20 @@ def run(self): if error_range is None: error_range = np.max(np.abs(ds.ssh_error.values)) - plot_horiz_field(ds, ds_mesh, 'ssh', ax=axes[i, 0], - cmap='cmo.balance', t_index=ds.sizes["Time"] - 1, - vmin=-eta0, vmax=eta0, cmap_title="SSH (m)") + cell_mask = ds_init.maxLevelCell >= 1 + patches, patch_mask = plot_horiz_field( + ds, ds_mesh, 'ssh', ax=axes[i, 0], cmap='cmo.balance', + t_index=ds.sizes["Time"] - 1, vmin=-eta0, vmax=eta0, + cmap_title="SSH (m)", cell_mask=cell_mask) plot_horiz_field(ds, ds_mesh, 'ssh_exact', ax=axes[i, 1], cmap='cmo.balance', - vmin=-eta0, vmax=eta0, cmap_title="SSH (m)") + vmin=-eta0, vmax=eta0, cmap_title="SSH (m)", + patches=patches, patch_mask=patch_mask) plot_horiz_field(ds, ds_mesh, 'ssh_error', ax=axes[i, 2], cmap='cmo.balance', cmap_title=r"$\Delta$ SSH (m)", - vmin=-error_range, vmax=error_range) + vmin=-error_range, vmax=error_range, + patches=patches, patch_mask=patch_mask) axes[0, 0].set_title('Numerical solution') axes[0, 1].set_title('Analytical solution') diff --git a/polaris/ocean/tasks/manufactured_solution/viz.py b/polaris/ocean/tasks/manufactured_solution/viz.py index 6541765a8..65a232f7b 100644 --- a/polaris/ocean/tasks/manufactured_solution/viz.py +++ b/polaris/ocean/tasks/manufactured_solution/viz.py @@ -76,7 +76,6 @@ def run(self): ds_mesh = xr.open_dataset(f'mesh_{mesh_name}.nc') ds_init = xr.open_dataset(f'init_{mesh_name}.nc') ds = xr.open_dataset(f'output_{mesh_name}.nc') - ds['maxLevelCell'] = ds_init.maxLevelCell exact = ExactSolution(config, ds_init) t0 = datetime.datetime.strptime(ds.xtime.values[0].decode(), @@ -93,15 +92,19 @@ def run(self): if error_range is None: error_range = np.max(np.abs(ds.ssh_error.values)) - plot_horiz_field(ds, ds_mesh, 'ssh', ax=axes[i, 0], - cmap='cmo.balance', t_index=ds.sizes["Time"] - 1, - vmin=-eta0, vmax=eta0, cmap_title="SSH") + cell_mask = ds_init.maxLevelCell >= 1 + patches, patch_mask = plot_horiz_field( + ds, ds_mesh, 'ssh', ax=axes[i, 0], cmap='cmo.balance', + t_index=ds.sizes["Time"] - 1, vmin=-eta0, vmax=eta0, + cmap_title="SSH", cell_mask=cell_mask) plot_horiz_field(ds, ds_mesh, 'ssh_exact', ax=axes[i, 1], cmap='cmo.balance', - vmin=-eta0, vmax=eta0, cmap_title="SSH") + vmin=-eta0, vmax=eta0, cmap_title="SSH", + patches=patches, patch_mask=patch_mask) plot_horiz_field(ds, ds_mesh, 'ssh_error', ax=axes[i, 2], cmap='cmo.balance', cmap_title="dSSH", - vmin=-error_range, vmax=error_range) + vmin=-error_range, vmax=error_range, + patches=patches, patch_mask=patch_mask) axes[0, 0].set_title('Numerical solution') axes[0, 1].set_title('Analytical solution') diff --git a/polaris/viz/planar.py b/polaris/viz/planar.py index d2b0b0925..fef04eda2 100644 --- a/polaris/viz/planar.py +++ b/polaris/viz/planar.py @@ -14,7 +14,9 @@ def plot_horiz_field(ds, ds_mesh, field_name, out_file_name=None, # noqa: C901 ax=None, title=None, t_index=None, z_index=None, vmin=None, vmax=None, show_patch_edges=False, cmap=None, cmap_set_under=None, cmap_set_over=None, - cmap_scale='linear', cmap_title=None, figsize=None): + cmap_scale='linear', cmap_title=None, figsize=None, + vert_dim='nVertLevels', cell_mask=None, patches=None, + patch_mask=None): """ Plot a horizontal field from a planar domain using x,y coordinates at a single time and depth slice. @@ -27,13 +29,16 @@ def plot_horiz_field(ds, ds_mesh, field_name, out_file_name=None, # noqa: C901 ds_mesh : xarray.Dataset A data set containing horizontal mesh variables - field_name: str + field_name : str The name of the variable to plot, which must be present in ds - out_file_name: str + out_file_name : str, optional The path to which the plot image should be written - title: str, optional + ax : matplotlib.axes.Axes + Axes to plot to if making a multi-panel figure + + title : str, optional The title of the plot vmin : float, optional @@ -65,12 +70,36 @@ def plot_horiz_field(ds, ds_mesh, field_name, out_file_name=None, # noqa: C901 cmap_scale : {'log', 'linear'}, optional Whether the colormap is logarithmic or linear - cmap_title : str + cmap_title : str, optional Title for color bar - figsize : tuple + figsize : tuple, optional The width and height of the figure in inches. Default is determined based on the aspect ratio of the domain. + + vert_dim : str, optional + Name of the vertical dimension + + cell_mask : numpy.ndarray, optional + A ``bool`` mask indicating where cells are valid, used to mask fields + on both cells and edges. Not used if ``patches`` and ``patch_mask`` + are supplied + + patches : list of numpy.ndarray, optional + Patches from a previous call to ``plot_horiz_field()`` + + patch_mask : numpy.ndarray, optional + A mask of where the field has patches from a previous call to + ``plot_horiz_field()`` + + Returns + ------- + patches : list of numpy.ndarray + Patches to reuse for future plots. Patches for cells can only be + reused for other plots on cells and similarly for edges. + + patch_mask : numpy.ndarray + A mask used to select entries in the field that have patches """ use_mplstyle() @@ -89,9 +118,6 @@ def plot_horiz_field(ds, ds_mesh, field_name, out_file_name=None, # noqa: C901 if title is None: title = field_name - if 'maxLevelCell' not in ds: - raise ValueError( - 'maxLevelCell must be added to ds before plotting.') if field_name not in ds: raise ValueError( f'{field_name} must be present in ds before plotting.') @@ -102,37 +128,48 @@ def plot_horiz_field(ds, ds_mesh, field_name, out_file_name=None, # noqa: C901 t_index = 0 if t_index is not None: field = field.isel(Time=t_index) - if 'nVertLevels' in field.dims and z_index is None: + if vert_dim in field.dims and z_index is None: z_index = 0 if z_index is not None: - field = field.isel(nVertLevels=z_index) - - if 'nCells' in field.dims: - ocean_mask = ds.maxLevelCell - 1 >= 0 - ocean_patches, ocean_mask = _compute_cell_patches(ds_mesh, ocean_mask) - elif 'nEdges' in field.dims: - ocean_mask = np.ones_like(field, dtype='bool') - ocean_mask = _remove_boundary_edges_from_mask(ds_mesh, ocean_mask) - ocean_patches, ocean_mask = _compute_edge_patches(ds_mesh, ocean_mask) - ocean_patches.set_array(field[ocean_mask]) + field = field.isel({vert_dim: z_index}) + + if patches is not None: + if patch_mask is None: + raise ValueError('You must supply both patches and patch_mask ' + 'from a previous call to plot_horiz_field()') + else: + if cell_mask is None: + cell_mask = np.ones_like(field, type='bool') + if 'nCells' in field.dims: + patch_mask = cell_mask + patches, patch_mask = _compute_cell_patches(ds_mesh, patch_mask) + elif 'nEdges' in field.dims: + patch_mask = _edge_mask_from_cell_mask(ds_mesh, cell_mask) + patch_mask = _remove_boundary_edges_from_mask(ds_mesh, patch_mask) + patches, patch_mask = _compute_edge_patches(ds_mesh, patch_mask) + else: + raise ValueError('Cannot plot a field without dim nCells or ' + 'nEdges') + local_patches = PatchCollection(patches, alpha=1.) + local_patches.set_array(field[patch_mask]) if cmap is not None: - ocean_patches.set_cmap(cmap) + local_patches.set_cmap(cmap) if cmap_set_under is not None: - current_cmap = ocean_patches.get_cmap() + current_cmap = local_patches.get_cmap() current_cmap.set_under(cmap_set_under) if cmap_set_over is not None: - current_cmap = ocean_patches.get_cmap() + current_cmap = local_patches.get_cmap() current_cmap.set_over(cmap_set_over) if show_patch_edges: - ocean_patches.set_edgecolor('black') + local_patches.set_edgecolor('black') else: - ocean_patches.set_edgecolor('face') - ocean_patches.set_clim(vmin=vmin, vmax=vmax) + local_patches.set_edgecolor('face') + local_patches.set_clim(vmin=vmin, vmax=vmax) if cmap_scale == 'log': - ocean_patches.set_norm(LogNorm(vmin=max(1e-10, vmin), - vmax=vmax, clip=False)) + local_patches.set_norm(LogNorm(vmin=max(1e-10, vmin), + vmax=vmax, clip=False)) if figsize is None: width = ds_mesh.xCell.max() - ds_mesh.xCell.min() @@ -145,12 +182,12 @@ def plot_horiz_field(ds, ds_mesh, field_name, out_file_name=None, # noqa: C901 if create_fig: plt.figure(figsize=figsize) ax = plt.subplot(111) - ax.add_collection(ocean_patches) + ax.add_collection(local_patches) ax.set_xlabel('x (km)') ax.set_ylabel('y (km)') ax.set_aspect('equal') ax.autoscale(tight=True) - cbar = plt.colorbar(ocean_patches, extend='both', shrink=0.7, ax=ax) + cbar = plt.colorbar(local_patches, extend='both', shrink=0.7, ax=ax) if cmap_title is not None: cbar.set_label(cmap_title) if create_fig: @@ -158,6 +195,19 @@ def plot_horiz_field(ds, ds_mesh, field_name, out_file_name=None, # noqa: C901 plt.savefig(out_file_name, bbox_inches='tight', pad_inches=0.2) plt.close() + return patches, patch_mask + + +def _edge_mask_from_cell_mask(ds, cell_mask): + cells_on_edge = ds.cellsOnEdge - 1 + valid = cells_on_edge >= 0 + # the edge mask is True if either adjacent cell is valid and its mask is + # True + edge_mask = np.logical_or( + np.logical_and(valid[:, 0], cell_mask[cells_on_edge[:, 0]]), + np.logical_and(valid[:, 1], cell_mask[cells_on_edge[:, 1]])) + return edge_mask + def _remove_boundary_edges_from_mask(ds, mask): area_cell = ds.areaCell.values @@ -221,9 +271,7 @@ def _compute_cell_patches(ds, mask): polygon = Polygon(vertices, closed=True) patches.append(polygon) - p = PatchCollection(patches, alpha=1.) - - return p, mask + return patches, mask def _compute_edge_patches(ds, mask): @@ -252,6 +300,4 @@ def _compute_edge_patches(ds, mask): polygon = Polygon(vertices, closed=True) patches.append(polygon) - p = PatchCollection(patches, alpha=1.) - - return p, mask + return patches, mask