diff --git a/esmf_regrid/esmf_regridder.py b/esmf_regrid/esmf_regridder.py index 5322c415..a2641c7a 100644 --- a/esmf_regrid/esmf_regridder.py +++ b/esmf_regrid/esmf_regridder.py @@ -16,21 +16,29 @@ ] -def _get_regrid_weights_dict(src_field, tgt_field, regrid_method): +def _get_regrid_weights_dict(src_field, tgt_field, regrid_method, esmf_args=None): + if esmf_args is None: + esmf_args = {} + else: + esmf_args = esmf_args.copy() + # Provide default values + if "ignore_degenerate" not in esmf_args: + esmf_args["ignore_degenerate"] = True + if "unmapped_action" not in esmf_args: + esmf_args["unmapped_action"] = esmpy.UnmappedAction.IGNORE # The value, in array form, that ESMF should treat as an affirmative mask. expected_mask = np.array([True]) regridder = esmpy.Regrid( src_field, tgt_field, - ignore_degenerate=True, regrid_method=regrid_method, - unmapped_action=esmpy.UnmappedAction.IGNORE, # Choosing the norm_type DSTAREA allows for mdtol type operations # to be performed using the weights information later on. norm_type=esmpy.NormType.DSTAREA, src_mask_values=expected_mask, dst_mask_values=expected_mask, factors=True, + **esmf_args, ) # Without specifying deep_copy=true, the information in weights_dict # would be corrupted when the ESMF regridder is destoyed. @@ -59,7 +67,12 @@ class Regridder: """Regridder for directly interfacing with :mod:`esmpy`.""" def __init__( - self, src, tgt, method=Constants.Method.CONSERVATIVE, precomputed_weights=None + self, + src, + tgt, + method=Constants.Method.CONSERVATIVE, + esmf_args=None, + precomputed_weights=None, ): """ Create a regridder from descriptions of horizontal grids/meshes. @@ -85,6 +98,8 @@ def __init__( If ``None``, :mod:`esmpy` will be used to calculate regridding weights. Otherwise, :mod:`esmpy` will be bypassed and ``precomputed_weights`` will be used as the regridding weights. + esmf_args : dict, optional + A dictionary of arguments to pass to ESMF. """ self.src = src self.tgt = tgt @@ -98,6 +113,7 @@ def __init__( src.make_esmf_field(), tgt.make_esmf_field(), regrid_method=method.value, + esmf_args=esmf_args, ) self.weight_matrix = _weights_dict_to_sparse_array( weights_dict, diff --git a/esmf_regrid/experimental/io.py b/esmf_regrid/experimental/io.py index b43a6619..2c285bde 100644 --- a/esmf_regrid/experimental/io.py +++ b/esmf_regrid/experimental/io.py @@ -9,8 +9,7 @@ import scipy.sparse import esmf_regrid -from esmf_regrid import _load_context -from esmf_regrid import check_method, Constants +from esmf_regrid import _load_context, check_method, Constants, esmpy from esmf_regrid.experimental.unstructured_scheme import ( GridToMeshESMFRegridder, MeshToGridESMFRegridder, @@ -48,6 +47,29 @@ RESOLUTION = "resolution" SOURCE_RESOLUTION = "src_resolution" TARGET_RESOLUTION = "tgt_resolution" +ESMF_ARGS = "esmf_args" +VALID_ESMF_KWARGS = [ + "pole_method", + "regrid_pole_npoints", + "line_type", + "extrap_method", + "extrap_num_src_pnts", + "extrap_dist_exponent", + "extrap_num_levels", + "unmapped_action", + "ignore_degenerate", + "large_file", +] +POLE_METHOD_DICT = {e.name: e for e in esmpy.PoleMethod} +LINE_TYPE_DICT = {e.name: e for e in esmpy.LineType} +EXTRAP_METHOD_DICT = {e.name: e for e in esmpy.ExtrapMethod} +UNMAPPED_ACTION_DICT = {e.name: e for e in esmpy.UnmappedAction} +ESMF_ENUM_ARGS = { + "pole_method": POLE_METHOD_DICT, + "line_type": LINE_TYPE_DICT, + "extrap_method": EXTRAP_METHOD_DICT, + "unmapped_action": UNMAPPED_ACTION_DICT, +} def _add_mask_to_cube(mask, cube, name): @@ -254,6 +276,21 @@ def _standard_mesh_cube(mesh, location, name): weights_cube.add_aux_coord(row_coord, 0) weights_cube.add_aux_coord(col_coord, 0) + esmf_args = rg.esmf_args + if esmf_args is None: + esmf_args = {} + for arg in esmf_args.keys(): + if arg not in VALID_ESMF_KWARGS: + raise KeyError(f"{arg} is not considered a valid argument to pass to ESMF.") + esmf_arg_attributes = { + k: v.name if hasattr(v, "name") else int(v) if isinstance(v, bool) else v + for k, v in esmf_args.items() + } + esmf_arg_coord = AuxCoord( + 0, var_name=ESMF_ARGS, long_name=ESMF_ARGS, attributes=esmf_arg_attributes + ) + weights_cube.add_aux_coord(esmf_arg_coord) + weight_shape_cube = Cube( weight_shape, var_name=WEIGHTS_SHAPE_NAME, @@ -342,6 +379,11 @@ def load_regridder(filename): else: use_tgt_mask = False + esmf_args = weights_cube.coord(ESMF_ARGS).attributes + for arg, dict in ESMF_ENUM_ARGS.items(): + if arg in esmf_args: + esmf_args[arg] = dict[esmf_args[arg]] + if scheme is GridToMeshESMFRegridder: resolution_keyword = SOURCE_RESOLUTION kwargs = {resolution_keyword: resolution, "method": method, "mdtol": mdtol} @@ -365,6 +407,7 @@ def load_regridder(filename): precomputed_weights=weight_matrix, use_src_mask=use_src_mask, use_tgt_mask=use_tgt_mask, + esmf_args=esmf_args, **kwargs, ) diff --git a/esmf_regrid/experimental/unstructured_scheme.py b/esmf_regrid/experimental/unstructured_scheme.py index 172bfd4e..46f445cb 100644 --- a/esmf_regrid/experimental/unstructured_scheme.py +++ b/esmf_regrid/experimental/unstructured_scheme.py @@ -119,6 +119,7 @@ def __init__( tgt_resolution=None, use_src_mask=False, use_tgt_mask=False, + esmf_args=None, ): """ Create regridder for conversions between source mesh and target grid. @@ -157,6 +158,8 @@ def __init__( a boolean value. If True, this array is taken from the mask on the data in ``tgt``. If False, no mask will be taken and all points will be used in weights calculation. + esmf_args : dict, optional + A dictionary of arguments to pass to ESMF. Raises ------ @@ -177,6 +180,7 @@ def __init__( tgt_resolution=tgt_resolution, use_src_mask=use_src_mask, use_tgt_mask=use_tgt_mask, + esmf_args=esmf_args, ) self.resolution = tgt_resolution self.mesh, self.location = self._src @@ -286,6 +290,7 @@ def __init__( use_src_mask=False, use_tgt_mask=False, tgt_location=None, + esmf_args=None, ): """ Create regridder for conversions between source grid and target mesh. @@ -328,6 +333,8 @@ def __init__( tgt_location : str or None, default=None Either "face" or "node". Describes the location for data on the mesh if the target is not a :class:`~iris.cube.Cube`. + esmf_args : dict, optional + A dictionary of arguments to pass to ESMF. Raises ------ @@ -348,6 +355,7 @@ def __init__( use_src_mask=use_src_mask, use_tgt_mask=use_tgt_mask, tgt_location=tgt_location, + esmf_args=esmf_args, ) self.resolution = src_resolution self.mesh, self.location = self._tgt diff --git a/esmf_regrid/schemes.py b/esmf_regrid/schemes.py index 69ba70e1..d768038d 100644 --- a/esmf_regrid/schemes.py +++ b/esmf_regrid/schemes.py @@ -19,7 +19,7 @@ except ImportError: raise exc -from esmf_regrid import check_method, Constants +from esmf_regrid import check_method, Constants, esmpy from esmf_regrid.esmf_regridder import GridInfo, RefinedGridInfo, Regridder from esmf_regrid.experimental.unstructured_regrid import MeshInfo @@ -33,6 +33,12 @@ "regrid_rectilinear_to_rectilinear", ] +STANDAR_BILINEAR_EXTAP_ARGS = { + "extrap_method": esmpy.ExtrapMethod.NEAREST_IDAVG, + "extrap_num_src_pnts": 2, + "extrap_dist_exponent": 1, +} + def _get_coord(cube, axis): try: @@ -292,6 +298,42 @@ def _regrid_along_dims(data, regridder, dims, num_out_dims, mdtol): return result +def _check_esmf_args(kwargs): + if kwargs is not None: + if not isinstance(kwargs, dict): + raise TypeError(f"Expected `esmf_args` to be a dict, got a {type(kwargs)}.") + invalid_kwargs = [ + "filename", + "norm_type", + "rh_filename", + "regrid_method", + "src_mask_values", + "dst_mask_values", + "factors", + "src_frac_field", + "dst_frac_field", + ] + valid_kwargs = [ + "pole method", + "regrid_pole_npoints", + "line_type", + "extrap_method", + "extrap_num_src_pnts", + "extrap_dist_exponent", + "extrap_num_levels", + "unmapped_action", + "ignore_degenerate", + "large_file", + ] + for kwarg in kwargs.keys(): + if kwarg in invalid_kwargs: + msg = f"`esmpy.Regrid` argument `{kwarg}` cannot be controlled by `esmf_args`." + raise ValueError(msg) + if kwarg not in valid_kwargs: + msg = f"`{kwarg}` is not a valid argument for `esmpy.Regrid`." + raise ValueError(msg) + + def _map_complete_blocks( src, func, active_dims, out_sizes, *args, dtype=None, **kwargs ): @@ -534,6 +576,7 @@ def _regrid_rectilinear_to_rectilinear__prepare( tgt_resolution=None, src_mask=None, tgt_mask=None, + esmf_args=None, ): """ First (setup) part of 'regrid_rectilinear_to_rectilinear'. @@ -556,8 +599,14 @@ def _regrid_rectilinear_to_rectilinear__prepare( srcinfo = _make_gridinfo(src_grid_cube, method, src_resolution, src_mask) tgtinfo = _make_gridinfo(tgt_grid_cube, method, tgt_resolution, tgt_mask) + _check_esmf_args(esmf_args) + regridder = Regridder( - srcinfo, tgtinfo, method=method, precomputed_weights=precomputed_weights + srcinfo, + tgtinfo, + method=method, + esmf_args=esmf_args, + precomputed_weights=precomputed_weights, ) regrid_info = RegridInfo( @@ -620,6 +669,7 @@ def _regrid_unstructured_to_rectilinear__prepare( tgt_resolution=None, src_mask=None, tgt_mask=None, + esmf_args=None, ): """ First (setup) part of 'regrid_unstructured_to_rectilinear'. @@ -638,8 +688,14 @@ def _regrid_unstructured_to_rectilinear__prepare( meshinfo = _make_meshinfo(src_mesh_cube, method, src_mask, "source") gridinfo = _make_gridinfo(target_grid_cube, method, tgt_resolution, tgt_mask) + _check_esmf_args(esmf_args) + regridder = Regridder( - meshinfo, gridinfo, method=method, precomputed_weights=precomputed_weights + meshinfo, + gridinfo, + method=method, + esmf_args=esmf_args, + precomputed_weights=precomputed_weights, ) regrid_info = RegridInfo( @@ -706,6 +762,7 @@ def _regrid_rectilinear_to_unstructured__prepare( src_mask=None, tgt_mask=None, tgt_location=None, + esmf_args=None, ): """ First (setup) part of 'regrid_rectilinear_to_unstructured'. @@ -734,8 +791,14 @@ def _regrid_rectilinear_to_unstructured__prepare( ) gridinfo = _make_gridinfo(src_grid_cube, method, src_resolution, src_mask) + _check_esmf_args(esmf_args) + regridder = Regridder( - gridinfo, meshinfo, method=method, precomputed_weights=precomputed_weights + gridinfo, + meshinfo, + method=method, + esmf_args=esmf_args, + precomputed_weights=precomputed_weights, ) regrid_info = RegridInfo( @@ -804,6 +867,7 @@ def _regrid_unstructured_to_unstructured__prepare( tgt_mask=None, src_location=None, tgt_location=None, + esmf_args=None, ): """ First (setup) part of 'regrid_unstructured_to_unstructured'. @@ -833,6 +897,7 @@ def _regrid_unstructured_to_unstructured__prepare( tgt_meshinfo, method=method, precomputed_weights=precomputed_weights, + esmf_args=esmf_args, ) regrid_info = RegridInfo( @@ -955,7 +1020,12 @@ class ESMFAreaWeighted: """ def __init__( - self, mdtol=0, use_src_mask=False, use_tgt_mask=False, tgt_location="face" + self, + mdtol=0, + use_src_mask=False, + use_tgt_mask=False, + tgt_location="face", + esmf_args={}, ): """ Area-weighted scheme for regridding between rectilinear grids. @@ -979,6 +1049,8 @@ def __init__( tgt_location : str or None, default="face" Either "face" or "node". Describes the location for data on the mesh if the target is not a :class:`~iris.cube.Cube`. + esmf_args : dict, optional + A dictionary of arguments to pass to ESMF. """ if not (0 <= mdtol <= 1): @@ -992,10 +1064,20 @@ def __init__( self.use_src_mask = use_src_mask self.use_tgt_mask = use_tgt_mask self.tgt_location = "face" + _check_esmf_args(esmf_args) + self.esmf_args = esmf_args def __repr__(self): """Return a representation of the class.""" - return "ESMFAreaWeighted(mdtol={})".format(self.mdtol) + result = ( + f"ESMFAreaWeighted(" + f" mdtol={self.mdtol}," + f" use_src_mask={self.use_src_mask}," + f" use_tgt_mask={self.use_tgt_mask}," + f" esmf_args={self.esmf_args}," + f")" + ) + return result def regridder( self, @@ -1006,6 +1088,7 @@ def regridder( use_src_mask=None, use_tgt_mask=None, tgt_location="face", + esmf_args=None, ): """ Create regridder to perform regridding from ``src_grid`` to ``tgt_grid``. @@ -1035,6 +1118,8 @@ def regridder( tgt_location : str or None, default="face" Either "face" or "node". Describes the location for data on the mesh if the target is not a :class:`~iris.cube.Cube`. + esmf_args : dict, optional + A dictionary of arguments to pass to ESMF. Returns @@ -1055,6 +1140,9 @@ def regridder( use_src_mask = self.use_src_mask if use_tgt_mask is None: use_tgt_mask = self.use_tgt_mask + if esmf_args is None: + esmf_args = self.esmf_args + if tgt_location is not None and tgt_location != "face": raise ValueError( "For area weighted regridding, target location must be 'face'." @@ -1068,6 +1156,7 @@ def regridder( use_src_mask=use_src_mask, use_tgt_mask=use_tgt_mask, tgt_location="face", + esmf_args=esmf_args, ) @@ -1081,7 +1170,13 @@ class ESMFBilinear: """ def __init__( - self, mdtol=0, use_src_mask=False, use_tgt_mask=False, tgt_location=None + self, + mdtol=0, + use_src_mask=False, + use_tgt_mask=False, + tgt_location=None, + extrapolate_gaps=False, + esmf_args={}, ): """ Area-weighted scheme for regridding between rectilinear grids. @@ -1101,6 +1196,12 @@ def __init__( tgt_location : str or None, default=None Either "face" or "node". Describes the location for data on the mesh if the target is not a :class:`~iris.cube.Cube`. + extrapolate_gaps : bool, default=False + Use a standard set of ESMF arguments for extrapolation which achieves + continuity with bilinear regridding. Useful for situations where gaps + between cells would be masked. + esmf_args : dict, optional + A dictionary of arguments to pass to ESMF. """ if not (0 <= mdtol <= 1): @@ -1110,10 +1211,22 @@ def __init__( self.use_src_mask = use_src_mask self.use_tgt_mask = use_tgt_mask self.tgt_location = tgt_location + if extrapolate_gaps: + esmf_args = STANDAR_BILINEAR_EXTAP_ARGS + _check_esmf_args(esmf_args) + self.esmf_args = esmf_args def __repr__(self): """Return a representation of the class.""" - return "ESMFBilinear(mdtol={})".format(self.mdtol) + result = ( + f"ESMFBilinear(" + f" mdtol={self.mdtol}," + f" use_src_mask={self.use_src_mask}," + f" use_tgt_mask={self.use_tgt_mask}," + f" esmf_args={self.esmf_args}," + f")" + ) + return result def regridder( self, @@ -1122,6 +1235,8 @@ def regridder( use_src_mask=None, use_tgt_mask=None, tgt_location=None, + extrapolate_gaps=False, + esmf_args=None, ): """ Create regridder to perform regridding from ``src_grid`` to ``tgt_grid``. @@ -1142,6 +1257,12 @@ def regridder( tgt_location : str or None, default=None Either "face" or "node". Describes the location for data on the mesh if the target is not a :class:`~iris.cube.Cube`. + extrapolate_gaps : bool, default=False + Use a standard set of ESMF arguments for extrapolation which achieves + continuity with bilinear regridding. Useful for situations where gaps + between cells would be masked. + esmf_args : dict, optional + A dictionary of arguments to pass to ESMF. Returns ------- @@ -1163,6 +1284,10 @@ def regridder( use_tgt_mask = self.use_tgt_mask if tgt_location is None: tgt_location = self.tgt_location + if esmf_args is None: + esmf_args = self.esmf_args + if extrapolate_gaps: + esmf_args = STANDAR_BILINEAR_EXTAP_ARGS return ESMFBilinearRegridder( src_grid, tgt_grid, @@ -1170,6 +1295,7 @@ def regridder( use_src_mask=use_src_mask, use_tgt_mask=use_tgt_mask, tgt_location=tgt_location, + esmf_args=esmf_args, ) @@ -1199,7 +1325,9 @@ class ESMFNearest: the same equivalent space will behave the same. """ - def __init__(self, use_src_mask=False, use_tgt_mask=False, tgt_location=None): + def __init__( + self, use_src_mask=False, use_tgt_mask=False, tgt_location=None, esmf_args={} + ): """ Nearest neighbour scheme for regridding between rectilinear grids. @@ -1214,14 +1342,25 @@ def __init__(self, use_src_mask=False, use_tgt_mask=False, tgt_location=None): tgt_location : str or None, default=None Either "face" or "node". Describes the location for data on the mesh if the target is not a :class:`~iris.cube.Cube`. + esmf_args : dict, optional + A dictionary of arguments to pass to ESMF. """ self.use_src_mask = use_src_mask self.use_tgt_mask = use_tgt_mask self.tgt_location = tgt_location + _check_esmf_args(esmf_args) + self.esmf_args = esmf_args def __repr__(self): """Return a representation of the class.""" - return "ESMFNearest()" + result = ( + f"ESMFNearest(" + f" use_src_mask={self.use_src_mask}," + f" use_tgt_mask={self.use_tgt_mask}," + f" esmf_args={self.esmf_args}," + f")" + ) + return result def regridder( self, @@ -1230,6 +1369,7 @@ def regridder( use_src_mask=None, use_tgt_mask=None, tgt_location=None, + esmf_args=None, ): """ Create regridder to perform regridding from ``src_grid`` to ``tgt_grid``. @@ -1250,6 +1390,8 @@ def regridder( tgt_location : str or None, default=None Either "face" or "node". Describes the location for data on the mesh if the target is not a :class:`~iris.cube.Cube`. + esmf_args : dict, optional + A dictionary of arguments to pass to ESMF. Returns ------- @@ -1271,12 +1413,15 @@ def regridder( use_tgt_mask = self.use_tgt_mask if tgt_location is None: tgt_location = self.tgt_location + if esmf_args is None: + esmf_args = self.esmf_args return ESMFNearestRegridder( src_grid, tgt_grid, use_src_mask=use_src_mask, use_tgt_mask=use_tgt_mask, tgt_location=tgt_location, + esmf_args=esmf_args, ) @@ -1292,6 +1437,7 @@ def __init__( use_src_mask=False, use_tgt_mask=False, tgt_location=None, + esmf_args={}, **kwargs, ): """ @@ -1340,6 +1486,9 @@ def __init__( self.mdtol = mdtol self.method = method + self.esmf_args = esmf_args + kwargs["esmf_args"] = self.esmf_args + self.src_mask = _get_mask(src, use_src_mask) kwargs["src_mask"] = self.src_mask self.tgt_mask = _get_mask(tgt, use_tgt_mask) @@ -1470,6 +1619,7 @@ def __init__( use_src_mask=False, use_tgt_mask=False, tgt_location="face", + esmf_args={}, ): """ Create regridder for conversions between ``src`` and ``tgt``. @@ -1508,6 +1658,8 @@ def __init__( tgt_location : str or None, default="face" Either "face" or "node". Describes the location for data on the mesh if the target is not a :class:`~iris.cube.Cube`. + esmf_args : dict, optional + A dictionary of arguments to pass to ESMF. Raises ------ @@ -1536,6 +1688,7 @@ def __init__( mdtol=mdtol, precomputed_weights=precomputed_weights, tgt_location="face", + esmf_args=esmf_args, **kwargs, ) @@ -1552,6 +1705,8 @@ def __init__( use_src_mask=False, use_tgt_mask=False, tgt_location=None, + extrapolate_gaps=False, + esmf_args={}, ): """ Create regridder for conversions between ``src`` and ``tgt``. @@ -1581,6 +1736,12 @@ def __init__( tgt_location : str or None, default=None Either "face" or "node". Describes the location for data on the mesh if the target is not a :class:`~iris.cube.Cube`. + extrapolate_gaps : bool, default=False + Use a standard set of ESMF arguments for extrapolation which achieves + continuity with bilinear regridding. Useful for situations where gaps + between cells would be masked. + esmf_args : dict, optional + A dictionary of arguments to pass to ESMF. Raises ------ @@ -1588,6 +1749,8 @@ def __init__( If ``use_src_mask`` or ``use_tgt_mask`` are True while the masks on ``src`` or ``tgt`` respectively are not constant over non-horizontal dimensions. """ + if extrapolate_gaps: + esmf_args = STANDAR_BILINEAR_EXTAP_ARGS super().__init__( src, tgt, @@ -1597,6 +1760,7 @@ def __init__( use_src_mask=use_src_mask, use_tgt_mask=use_tgt_mask, tgt_location=tgt_location, + esmf_args=esmf_args, ) @@ -1611,6 +1775,7 @@ def __init__( use_src_mask=False, use_tgt_mask=False, tgt_location=None, + esmf_args={}, ): """ Create regridder for conversions between ``src`` and ``tgt``. @@ -1634,6 +1799,8 @@ def __init__( tgt_location : str or None, default=None Either "face" or "node". Describes the location for data on the mesh if the target is not a :class:`~iris.cube.Cube`. + esmf_args : dict, optional + A dictionary of arguments to pass to ESMF. Raises ------ @@ -1650,4 +1817,5 @@ def __init__( use_src_mask=use_src_mask, use_tgt_mask=use_tgt_mask, tgt_location=tgt_location, + esmf_args=esmf_args, ) diff --git a/esmf_regrid/tests/unit/experimental/io/test_round_tripping.py b/esmf_regrid/tests/unit/experimental/io/test_round_tripping.py index a53f034f..30d594a7 100644 --- a/esmf_regrid/tests/unit/experimental/io/test_round_tripping.py +++ b/esmf_regrid/tests/unit/experimental/io/test_round_tripping.py @@ -10,6 +10,7 @@ ESMFAreaWeightedRegridder, ESMFBilinear, ESMFNearest, + esmpy, ) from esmf_regrid.experimental.io import load_regridder, save_regridder from esmf_regrid.experimental.unstructured_scheme import ( @@ -456,7 +457,14 @@ def test_generic_regridder(tmp_path, src_type, tgt_type, scheme): elif tgt_type == "mesh": tgt = _gridlike_mesh_cube(n_lons_tgt, n_lats_tgt) - original_rg = scheme().regridder(src, tgt) + esmf_args = { + "unmapped_action": esmpy.UnmappedAction.ERROR, + "ignore_degenerate": False, + "line_type": esmpy.LineType.CART, + "large_file": True, + } + + original_rg = scheme().regridder(src, tgt, esmf_args=esmf_args) filename = tmp_path / "regridder.nc" save_regridder(original_rg, filename) loaded_rg = load_regridder(str(filename)) @@ -469,6 +477,7 @@ def test_generic_regridder(tmp_path, src_type, tgt_type, scheme): assert original_rg.src_resolution == loaded_rg.src_resolution assert original_rg.tgt_resolution == loaded_rg.tgt_resolution assert original_rg.mdtol == loaded_rg.mdtol + assert original_rg.esmf_args == loaded_rg.esmf_args @pytest.mark.parametrize( diff --git a/esmf_regrid/tests/unit/schemes/__init__.py b/esmf_regrid/tests/unit/schemes/__init__.py index c51deffd..92ab695c 100644 --- a/esmf_regrid/tests/unit/schemes/__init__.py +++ b/esmf_regrid/tests/unit/schemes/__init__.py @@ -6,6 +6,7 @@ from numpy import ma import pytest +from esmf_regrid import esmpy from esmf_regrid.schemes import ESMFAreaWeighted, ESMFBilinear, ESMFNearest from esmf_regrid.tests.unit.schemes.test__cube_to_GridInfo import _grid_cube from esmf_regrid.tests.unit.schemes.test__mesh_to_MeshInfo import ( @@ -252,3 +253,51 @@ def _test_dtype_handling(scheme, src_type, tgt_type, in_dtype): assert result.lazy_data().dtype == expected_dtype assert result.data.dtype == expected_dtype + + +def _test_esmf_args(scheme): + """Test regridding scheme handles esmf_args as expected.""" + n_lons_src = 6 + n_lons_tgt = 3 + n_lats_src = 4 + n_lats_tgt = 2 + lon_bounds = (-180, 180) + lat_bounds = (-90, 90) + + src = _grid_cube(n_lons_src, n_lats_src, lon_bounds, lat_bounds, circular=True) + tgt = _grid_cube(n_lons_tgt, n_lats_tgt, lon_bounds, lat_bounds, circular=True) + + valid_esmf_args = { + "unmapped_action": esmpy.UnmappedAction.ERROR, + "ignore_degenerate": False, + "line_type": esmpy.LineType.CART, + "large_file": True, + } + + rg_1 = scheme(esmf_args=valid_esmf_args).regridder(src, tgt) + rg_2 = scheme().regridder(src, tgt, esmf_args=valid_esmf_args) + + assert rg_1.esmf_args == valid_esmf_args + assert rg_2.esmf_args == valid_esmf_args + + invalid_esmf_args_duplicate = {"regrid_method": None} + invalid_esmf_args_incorrect = {"invalid_arg": None} + invalid_esmf_args_type = "invalid_arg" + + match_duplicate = "cannot be controlled by `esmf_args`" + with pytest.raises(ValueError, match=match_duplicate): + _ = scheme(esmf_args=invalid_esmf_args_duplicate) + with pytest.raises(ValueError, match=match_duplicate): + _ = scheme().regridder(src, tgt, esmf_args=invalid_esmf_args_duplicate) + + match_incorrect = "is not a valid argument for" + with pytest.raises(ValueError, match=match_incorrect): + _ = scheme(esmf_args=invalid_esmf_args_incorrect) + with pytest.raises(ValueError, match=match_incorrect): + _ = scheme().regridder(src, tgt, esmf_args=invalid_esmf_args_incorrect) + + match_type = "Expected `esmf_args` to be a dict, got a " + with pytest.raises(TypeError, match=match_type): + _ = scheme(esmf_args=invalid_esmf_args_type) + with pytest.raises(TypeError, match=match_type): + _ = scheme().regridder(src, tgt, esmf_args=invalid_esmf_args_type) diff --git a/esmf_regrid/tests/unit/schemes/test_ESMFAreaWeighted.py b/esmf_regrid/tests/unit/schemes/test_ESMFAreaWeighted.py index 50af91b3..f48ceae9 100644 --- a/esmf_regrid/tests/unit/schemes/test_ESMFAreaWeighted.py +++ b/esmf_regrid/tests/unit/schemes/test_ESMFAreaWeighted.py @@ -6,6 +6,7 @@ from esmf_regrid.tests.unit.schemes.__init__ import ( _test_cube_regrid, _test_dtype_handling, + _test_esmf_args, _test_invalid_mdtol, _test_mask_from_init, _test_mask_from_regridder, @@ -81,3 +82,8 @@ def test_dtype_handling(src_tgt_types, in_dtype): """Test regridding scheme handles dtype as expected.""" src_type, tgt_type = src_tgt_types _test_dtype_handling(ESMFAreaWeighted, src_type, tgt_type, in_dtype) + + +def test_esmf_args(): + """Test regridding scheme handles esmf_args as expected.""" + _test_esmf_args(ESMFAreaWeighted) diff --git a/esmf_regrid/tests/unit/schemes/test_ESMFBilinear.py b/esmf_regrid/tests/unit/schemes/test_ESMFBilinear.py index f6a9e30f..8b4a5a1f 100644 --- a/esmf_regrid/tests/unit/schemes/test_ESMFBilinear.py +++ b/esmf_regrid/tests/unit/schemes/test_ESMFBilinear.py @@ -6,6 +6,7 @@ from esmf_regrid.tests.unit.schemes.__init__ import ( _test_cube_regrid, _test_dtype_handling, + _test_esmf_args, _test_invalid_mdtol, _test_mask_from_init, _test_mask_from_regridder, @@ -70,3 +71,8 @@ def test_dtype_handling(src_tgt_types, in_dtype): """Test regridding scheme handles dtype as expected.""" src_type, tgt_type = src_tgt_types _test_dtype_handling(ESMFBilinear, src_type, tgt_type, in_dtype) + + +def test_esmf_args(): + """Test regridding scheme handles esmf_args as expected.""" + _test_esmf_args(ESMFBilinear) diff --git a/esmf_regrid/tests/unit/schemes/test_ESMFBilinearRegridder.py b/esmf_regrid/tests/unit/schemes/test_ESMFBilinearRegridder.py index da60f416..0a39172d 100644 --- a/esmf_regrid/tests/unit/schemes/test_ESMFBilinearRegridder.py +++ b/esmf_regrid/tests/unit/schemes/test_ESMFBilinearRegridder.py @@ -4,6 +4,7 @@ import numpy as np import pytest +from esmf_regrid import esmpy from esmf_regrid.schemes import ESMFBilinearRegridder from esmf_regrid.tests.unit.schemes.test__cube_to_GridInfo import ( _curvilinear_cube, @@ -329,3 +330,33 @@ def test_regrid_data(): ) result = rg(src) np.testing.assert_allclose(expected_data, result.data) + + +def test_extrapolate_gaps(): + """Test that `extrapolate_gaps` argument works as expected.""" + n_lons = 6 + n_lats = 5 + src_lon_bounds = (-140, 180) + tgt_lon_bounds = (-180, 180) + src_lat_bounds = (-80, 90) + tgt_lat_bounds = (-90, 90) + src = _grid_cube(n_lons, n_lats, src_lon_bounds, src_lat_bounds, circular=False) + src.data = np.arange(n_lons * n_lats).reshape(n_lats, n_lons) + tgt = _grid_cube(n_lons, n_lats, tgt_lon_bounds, tgt_lat_bounds, circular=True) + + extrapolate_regridder = ESMFBilinearRegridder(src, tgt, extrapolate_gaps=True) + normal_regridder = ESMFBilinearRegridder(src, tgt, extrapolate_gaps=False) + + extrapolate_result = extrapolate_regridder(src) + normal_result = normal_regridder(src) + + assert not np.ma.is_masked(extrapolate_result.data) + assert np.ma.is_masked(normal_result.data) + + expected_args = { + "extrap_method": esmpy.ExtrapMethod.NEAREST_IDAVG, + "extrap_num_src_pnts": 2, + "extrap_dist_exponent": 1, + } + assert extrapolate_regridder.esmf_args == expected_args + assert normal_regridder.esmf_args == {} diff --git a/esmf_regrid/tests/unit/schemes/test_ESMFNearest.py b/esmf_regrid/tests/unit/schemes/test_ESMFNearest.py index 6f146e6c..db90d830 100644 --- a/esmf_regrid/tests/unit/schemes/test_ESMFNearest.py +++ b/esmf_regrid/tests/unit/schemes/test_ESMFNearest.py @@ -7,6 +7,7 @@ from esmf_regrid.schemes import ESMFNearest from esmf_regrid.tests.unit.schemes.__init__ import ( _test_dtype_handling, + _test_esmf_args, _test_mask_from_init, _test_mask_from_regridder, _test_non_degree_crs, @@ -113,3 +114,8 @@ def test_dtype_handling(src_tgt_types, in_dtype): """Test regridding scheme handles dtype as expected.""" src_type, tgt_type = src_tgt_types _test_dtype_handling(ESMFNearest, src_type, tgt_type, in_dtype) + + +def test_esmf_args(): + """Test regridding scheme handles esmf_args as expected.""" + _test_esmf_args(ESMFNearest)