Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow ESMF arguments to be passed #396

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions esmf_regrid/esmf_regridder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,29 @@
]


def _get_regrid_weights_dict(src_field, tgt_field, regrid_method):
def _get_regrid_weights_dict(src_field, tgt_field, regrid_method, esmf_args=None):
if esmf_args is None:
esmf_args = {}
else:
esmf_args = esmf_args.copy()
# Provide default values
if "ignore_degenerate" not in esmf_args:
esmf_args["ignore_degenerate"] = True
if "unmapped_action" not in esmf_args:
esmf_args["unmapped_action"] = esmpy.UnmappedAction.IGNORE
# The value, in array form, that ESMF should treat as an affirmative mask.
expected_mask = np.array([True])
regridder = esmpy.Regrid(
src_field,
tgt_field,
ignore_degenerate=True,
regrid_method=regrid_method,
unmapped_action=esmpy.UnmappedAction.IGNORE,
# Choosing the norm_type DSTAREA allows for mdtol type operations
# to be performed using the weights information later on.
norm_type=esmpy.NormType.DSTAREA,
src_mask_values=expected_mask,
dst_mask_values=expected_mask,
factors=True,
**esmf_args,
)
# Without specifying deep_copy=true, the information in weights_dict
# would be corrupted when the ESMF regridder is destoyed.
Expand Down Expand Up @@ -59,7 +67,12 @@ class Regridder:
"""Regridder for directly interfacing with :mod:`esmpy`."""

def __init__(
self, src, tgt, method=Constants.Method.CONSERVATIVE, precomputed_weights=None
self,
src,
tgt,
method=Constants.Method.CONSERVATIVE,
esmf_args=None,
precomputed_weights=None,
):
"""
Create a regridder from descriptions of horizontal grids/meshes.
Expand All @@ -85,6 +98,8 @@ def __init__(
If ``None``, :mod:`esmpy` will be used to
calculate regridding weights. Otherwise, :mod:`esmpy` will be bypassed
and ``precomputed_weights`` will be used as the regridding weights.
esmf_args : dict, optional
A dictionary of arguments to pass to ESMF.
"""
self.src = src
self.tgt = tgt
Expand All @@ -98,6 +113,7 @@ def __init__(
src.make_esmf_field(),
tgt.make_esmf_field(),
regrid_method=method.value,
esmf_args=esmf_args,
)
self.weight_matrix = _weights_dict_to_sparse_array(
weights_dict,
Expand Down
47 changes: 45 additions & 2 deletions esmf_regrid/experimental/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
import scipy.sparse

import esmf_regrid
from esmf_regrid import _load_context
from esmf_regrid import check_method, Constants
from esmf_regrid import _load_context, check_method, Constants, esmpy
from esmf_regrid.experimental.unstructured_scheme import (
GridToMeshESMFRegridder,
MeshToGridESMFRegridder,
Expand Down Expand Up @@ -48,6 +47,29 @@
RESOLUTION = "resolution"
SOURCE_RESOLUTION = "src_resolution"
TARGET_RESOLUTION = "tgt_resolution"
ESMF_ARGS = "esmf_args"
VALID_ESMF_KWARGS = [
"pole_method",
"regrid_pole_npoints",
"line_type",
"extrap_method",
"extrap_num_src_pnts",
"extrap_dist_exponent",
"extrap_num_levels",
"unmapped_action",
"ignore_degenerate",
"large_file",
]
POLE_METHOD_DICT = {e.name: e for e in esmpy.PoleMethod}
LINE_TYPE_DICT = {e.name: e for e in esmpy.LineType}
EXTRAP_METHOD_DICT = {e.name: e for e in esmpy.ExtrapMethod}
UNMAPPED_ACTION_DICT = {e.name: e for e in esmpy.UnmappedAction}
ESMF_ENUM_ARGS = {
"pole_method": POLE_METHOD_DICT,
"line_type": LINE_TYPE_DICT,
"extrap_method": EXTRAP_METHOD_DICT,
"unmapped_action": UNMAPPED_ACTION_DICT,
}


def _add_mask_to_cube(mask, cube, name):
Expand Down Expand Up @@ -254,6 +276,21 @@ def _standard_mesh_cube(mesh, location, name):
weights_cube.add_aux_coord(row_coord, 0)
weights_cube.add_aux_coord(col_coord, 0)

esmf_args = rg.esmf_args
if esmf_args is None:
esmf_args = {}
for arg in esmf_args.keys():
if arg not in VALID_ESMF_KWARGS:
raise KeyError(f"{arg} is not considered a valid argument to pass to ESMF.")
esmf_arg_attributes = {
k: v.name if hasattr(v, "name") else int(v) if isinstance(v, bool) else v
for k, v in esmf_args.items()
}
esmf_arg_coord = AuxCoord(
0, var_name=ESMF_ARGS, long_name=ESMF_ARGS, attributes=esmf_arg_attributes
)
weights_cube.add_aux_coord(esmf_arg_coord)

weight_shape_cube = Cube(
weight_shape,
var_name=WEIGHTS_SHAPE_NAME,
Expand Down Expand Up @@ -342,6 +379,11 @@ def load_regridder(filename):
else:
use_tgt_mask = False

esmf_args = weights_cube.coord(ESMF_ARGS).attributes
for arg, dict in ESMF_ENUM_ARGS.items():
if arg in esmf_args:
esmf_args[arg] = dict[esmf_args[arg]]

if scheme is GridToMeshESMFRegridder:
resolution_keyword = SOURCE_RESOLUTION
kwargs = {resolution_keyword: resolution, "method": method, "mdtol": mdtol}
Expand All @@ -365,6 +407,7 @@ def load_regridder(filename):
precomputed_weights=weight_matrix,
use_src_mask=use_src_mask,
use_tgt_mask=use_tgt_mask,
esmf_args=esmf_args,
**kwargs,
)

Expand Down
8 changes: 8 additions & 0 deletions esmf_regrid/experimental/unstructured_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def __init__(
tgt_resolution=None,
use_src_mask=False,
use_tgt_mask=False,
esmf_args=None,
):
"""
Create regridder for conversions between source mesh and target grid.
Expand Down Expand Up @@ -157,6 +158,8 @@ def __init__(
a boolean value. If True, this array is taken from the mask on the data
in ``tgt``. If False, no mask will be taken and all points
will be used in weights calculation.
esmf_args : dict, optional
A dictionary of arguments to pass to ESMF.

Raises
------
Expand All @@ -177,6 +180,7 @@ def __init__(
tgt_resolution=tgt_resolution,
use_src_mask=use_src_mask,
use_tgt_mask=use_tgt_mask,
esmf_args=esmf_args,
)
self.resolution = tgt_resolution
self.mesh, self.location = self._src
Expand Down Expand Up @@ -286,6 +290,7 @@ def __init__(
use_src_mask=False,
use_tgt_mask=False,
tgt_location=None,
esmf_args=None,
):
"""
Create regridder for conversions between source grid and target mesh.
Expand Down Expand Up @@ -328,6 +333,8 @@ def __init__(
tgt_location : str or None, default=None
Either "face" or "node". Describes the location for data on the mesh
if the target is not a :class:`~iris.cube.Cube`.
esmf_args : dict, optional
A dictionary of arguments to pass to ESMF.

Raises
------
Expand All @@ -348,6 +355,7 @@ def __init__(
use_src_mask=use_src_mask,
use_tgt_mask=use_tgt_mask,
tgt_location=tgt_location,
esmf_args=esmf_args,
)
self.resolution = src_resolution
self.mesh, self.location = self._tgt
Expand Down
Loading