From 938575c9c9f5589e503de58c62a8a7119a9e9b73 Mon Sep 17 00:00:00 2001 From: "stephen.worsley" Date: Wed, 21 Aug 2024 09:28:32 +0100 Subject: [PATCH 1/6] allow ESMF arguments to be passed --- esmf_regrid/__init__.py | 2 +- esmf_regrid/esmf_regridder.py | 15 +- esmf_regrid/experimental/io.py | 49 +++++ .../experimental/unstructured_scheme.py | 8 + esmf_regrid/schemes.py | 189 +++++++++++++++++- .../experimental/io/test_round_tripping.py | 9 +- 6 files changed, 258 insertions(+), 14 deletions(-) diff --git a/esmf_regrid/__init__.py b/esmf_regrid/__init__.py index 9c6070d6..2472b508 100644 --- a/esmf_regrid/__init__.py +++ b/esmf_regrid/__init__.py @@ -16,7 +16,7 @@ raise exc if hasattr(_imesh, "PARSE_UGRID_ON_LOAD"): - _load_context = _imesh.PARSE_UGRID_ON_LOAD + _load_context = _imesh.PARSE_UGRID_ON_LOAD.context else: from contextlib import nullcontext diff --git a/esmf_regrid/esmf_regridder.py b/esmf_regrid/esmf_regridder.py index 5322c415..af136099 100644 --- a/esmf_regrid/esmf_regridder.py +++ b/esmf_regrid/esmf_regridder.py @@ -16,7 +16,9 @@ ] -def _get_regrid_weights_dict(src_field, tgt_field, regrid_method): +def _get_regrid_weights_dict(src_field, tgt_field, regrid_method, esmf_args=None): + if esmf_args is None: + esmf_args = {} # The value, in array form, that ESMF should treat as an affirmative mask. expected_mask = np.array([True]) regridder = esmpy.Regrid( @@ -31,6 +33,7 @@ def _get_regrid_weights_dict(src_field, tgt_field, regrid_method): src_mask_values=expected_mask, dst_mask_values=expected_mask, factors=True, + **esmf_args, ) # Without specifying deep_copy=true, the information in weights_dict # would be corrupted when the ESMF regridder is destoyed. @@ -59,7 +62,12 @@ class Regridder: """Regridder for directly interfacing with :mod:`esmpy`.""" def __init__( - self, src, tgt, method=Constants.Method.CONSERVATIVE, precomputed_weights=None + self, + src, + tgt, + method=Constants.Method.CONSERVATIVE, + esmf_args=None, + precomputed_weights=None, ): """ Create a regridder from descriptions of horizontal grids/meshes. @@ -85,6 +93,8 @@ def __init__( If ``None``, :mod:`esmpy` will be used to calculate regridding weights. Otherwise, :mod:`esmpy` will be bypassed and ``precomputed_weights`` will be used as the regridding weights. + esmf_args : dict, optional + A dictionary of arguments to pass to ESMF. """ self.src = src self.tgt = tgt @@ -98,6 +108,7 @@ def __init__( src.make_esmf_field(), tgt.make_esmf_field(), regrid_method=method.value, + esmf_args=esmf_args, ) self.weight_matrix = _weights_dict_to_sparse_array( weights_dict, diff --git a/esmf_regrid/experimental/io.py b/esmf_regrid/experimental/io.py index 9b92708e..26cce879 100644 --- a/esmf_regrid/experimental/io.py +++ b/esmf_regrid/experimental/io.py @@ -22,6 +22,7 @@ GridRecord, MeshRecord, ) +from esmf_regrid import esmpy SUPPORTED_REGRIDDERS = [ @@ -48,6 +49,33 @@ RESOLUTION = "resolution" SOURCE_RESOLUTION = "src_resolution" TARGET_RESOLUTION = "tgt_resolution" +ESMF_ARGS = "esmf_args" +# TODO: check this list is accurate +VALID_ESMF_KWARGS = [ + # "regrid_method", + "pole_method", + "regrid_pole_npoints", + "line_type", + "extrap_method", + "extrap_num_src_pnts", + "extrap_dist_exponent", + "extrap_num_levels", + "unmapped_action", + # "ignore_degenerate", + "large_file", +] +# REGRID_METHOD_DICT = {e.name: e for e in esmpy.RegridMethod} +POLE_METHOD_DICT = {e.name: e for e in esmpy.PoleMethod} +LINE_TYPE_DICT = {e.name: e for e in esmpy.LineType} +EXTRAP_METHOD_DICT = {e.name: e for e in esmpy.ExtrapMethod} +UNMAPPED_ACTION_DICT = {e.name: e for e in esmpy.UnmappedAction} +ESMF_ENUM_ARGS = { + # "regrid_method": REGRID_METHOD_DICT, + "pole_method": POLE_METHOD_DICT, + "line_type": LINE_TYPE_DICT, + "extrap_method": EXTRAP_METHOD_DICT, + "unmapped_action": UNMAPPED_ACTION_DICT, +} def _add_mask_to_cube(mask, cube, name): @@ -254,6 +282,21 @@ def _standard_mesh_cube(mesh, location, name): weights_cube.add_aux_coord(row_coord, 0) weights_cube.add_aux_coord(col_coord, 0) + esmf_args = rg.esmf_args + if esmf_args is None: + esmf_args = {} + for arg in esmf_args.keys(): + if arg not in VALID_ESMF_KWARGS: + raise KeyError(f"{arg} is not considered a valid argument to pass to ESMF.") + esmf_arg_attributes = { + k: v.name if hasattr(v, "name") else int(v) if isinstance(v, bool) else v + for k, v in esmf_args.items() + } + esmf_arg_coord = AuxCoord( + 0, var_name=ESMF_ARGS, long_name=ESMF_ARGS, attributes=esmf_arg_attributes + ) + weights_cube.add_aux_coord(esmf_arg_coord) + weight_shape_cube = Cube( weight_shape, var_name=WEIGHTS_SHAPE_NAME, @@ -340,6 +383,11 @@ def load_regridder(filename): else: use_tgt_mask = False + esmf_args = weights_cube.coord(ESMF_ARGS).attributes + for arg, dict in ESMF_ENUM_ARGS.items(): + if arg in esmf_args: + esmf_args[arg] = dict[esmf_args[arg]] + if scheme is GridToMeshESMFRegridder: resolution_keyword = SOURCE_RESOLUTION kwargs = {resolution_keyword: resolution, "method": method, "mdtol": mdtol} @@ -363,6 +411,7 @@ def load_regridder(filename): precomputed_weights=weight_matrix, use_src_mask=use_src_mask, use_tgt_mask=use_tgt_mask, + esmf_args=esmf_args, **kwargs, ) diff --git a/esmf_regrid/experimental/unstructured_scheme.py b/esmf_regrid/experimental/unstructured_scheme.py index 172bfd4e..46f445cb 100644 --- a/esmf_regrid/experimental/unstructured_scheme.py +++ b/esmf_regrid/experimental/unstructured_scheme.py @@ -119,6 +119,7 @@ def __init__( tgt_resolution=None, use_src_mask=False, use_tgt_mask=False, + esmf_args=None, ): """ Create regridder for conversions between source mesh and target grid. @@ -157,6 +158,8 @@ def __init__( a boolean value. If True, this array is taken from the mask on the data in ``tgt``. If False, no mask will be taken and all points will be used in weights calculation. + esmf_args : dict, optional + A dictionary of arguments to pass to ESMF. Raises ------ @@ -177,6 +180,7 @@ def __init__( tgt_resolution=tgt_resolution, use_src_mask=use_src_mask, use_tgt_mask=use_tgt_mask, + esmf_args=esmf_args, ) self.resolution = tgt_resolution self.mesh, self.location = self._src @@ -286,6 +290,7 @@ def __init__( use_src_mask=False, use_tgt_mask=False, tgt_location=None, + esmf_args=None, ): """ Create regridder for conversions between source grid and target mesh. @@ -328,6 +333,8 @@ def __init__( tgt_location : str or None, default=None Either "face" or "node". Describes the location for data on the mesh if the target is not a :class:`~iris.cube.Cube`. + esmf_args : dict, optional + A dictionary of arguments to pass to ESMF. Raises ------ @@ -348,6 +355,7 @@ def __init__( use_src_mask=use_src_mask, use_tgt_mask=use_tgt_mask, tgt_location=tgt_location, + esmf_args=esmf_args, ) self.resolution = src_resolution self.mesh, self.location = self._tgt diff --git a/esmf_regrid/schemes.py b/esmf_regrid/schemes.py index 69ba70e1..ba60f2cb 100644 --- a/esmf_regrid/schemes.py +++ b/esmf_regrid/schemes.py @@ -19,7 +19,7 @@ except ImportError: raise exc -from esmf_regrid import check_method, Constants +from esmf_regrid import check_method, Constants, esmpy from esmf_regrid.esmf_regridder import GridInfo, RefinedGridInfo, Regridder from esmf_regrid.experimental.unstructured_regrid import MeshInfo @@ -33,6 +33,12 @@ "regrid_rectilinear_to_rectilinear", ] +STANDAR_BILINEAR_EXTAP_ARGS = { + "extrap_method": esmpy.ExtrapMethod.NEAREST_IDAVG, + "extrap_num_src_pnts": 2, + "extrap_dist_exponent": 1, +} + def _get_coord(cube, axis): try: @@ -292,6 +298,43 @@ def _regrid_along_dims(data, regridder, dims, num_out_dims, mdtol): return result +def _check_esmf_args(kwargs): + #TODO: raise proper warning messages + #TODO: check invalid and valid lists are appropriate + if kwargs is not None: + if not isinstance(kwargs, dict): + raise TypeError("") + invalid_kwargs = [ + "filename", + "norm_type", + "rh_filename", + "regrid_method", + "src_mask_values", + "dst_mask_values", + "factors", + "src_frac_field", + "dst_frac_field", + "ignore_degenerate", #TODO: check if this is worth controling + ] + valid_kwargs = [ + "pole method", + "regrid_pole_npoints", + "line_type", + "extrap_method", + "extrap_num_src_pnts", + "extrap_dist_exponent", + "extrap_num_levels", + "unmapped_action", + # "ignore_degenerate", + "large_file", + ] + for kwarg in kwargs.keys(): + if kwarg in invalid_kwargs: + raise ValueError("") + if kwarg not in valid_kwargs: + raise ValueError("") + + def _map_complete_blocks( src, func, active_dims, out_sizes, *args, dtype=None, **kwargs ): @@ -534,6 +577,7 @@ def _regrid_rectilinear_to_rectilinear__prepare( tgt_resolution=None, src_mask=None, tgt_mask=None, + esmf_args=None, ): """ First (setup) part of 'regrid_rectilinear_to_rectilinear'. @@ -556,8 +600,14 @@ def _regrid_rectilinear_to_rectilinear__prepare( srcinfo = _make_gridinfo(src_grid_cube, method, src_resolution, src_mask) tgtinfo = _make_gridinfo(tgt_grid_cube, method, tgt_resolution, tgt_mask) + _check_esmf_args(esmf_args) + regridder = Regridder( - srcinfo, tgtinfo, method=method, precomputed_weights=precomputed_weights + srcinfo, + tgtinfo, + method=method, + esmf_args=esmf_args, + precomputed_weights=precomputed_weights, ) regrid_info = RegridInfo( @@ -620,6 +670,7 @@ def _regrid_unstructured_to_rectilinear__prepare( tgt_resolution=None, src_mask=None, tgt_mask=None, + esmf_args=None, ): """ First (setup) part of 'regrid_unstructured_to_rectilinear'. @@ -638,8 +689,14 @@ def _regrid_unstructured_to_rectilinear__prepare( meshinfo = _make_meshinfo(src_mesh_cube, method, src_mask, "source") gridinfo = _make_gridinfo(target_grid_cube, method, tgt_resolution, tgt_mask) + _check_esmf_args(esmf_args) + regridder = Regridder( - meshinfo, gridinfo, method=method, precomputed_weights=precomputed_weights + meshinfo, + gridinfo, + method=method, + esmf_args=esmf_args, + precomputed_weights=precomputed_weights, ) regrid_info = RegridInfo( @@ -706,6 +763,7 @@ def _regrid_rectilinear_to_unstructured__prepare( src_mask=None, tgt_mask=None, tgt_location=None, + esmf_args=None, ): """ First (setup) part of 'regrid_rectilinear_to_unstructured'. @@ -734,8 +792,14 @@ def _regrid_rectilinear_to_unstructured__prepare( ) gridinfo = _make_gridinfo(src_grid_cube, method, src_resolution, src_mask) + _check_esmf_args(esmf_args) + regridder = Regridder( - gridinfo, meshinfo, method=method, precomputed_weights=precomputed_weights + gridinfo, + meshinfo, + method=method, + esmf_args=esmf_args, + precomputed_weights=precomputed_weights, ) regrid_info = RegridInfo( @@ -804,6 +868,7 @@ def _regrid_unstructured_to_unstructured__prepare( tgt_mask=None, src_location=None, tgt_location=None, + esmf_args=None, ): """ First (setup) part of 'regrid_unstructured_to_unstructured'. @@ -833,6 +898,7 @@ def _regrid_unstructured_to_unstructured__prepare( tgt_meshinfo, method=method, precomputed_weights=precomputed_weights, + esmf_args=esmf_args, ) regrid_info = RegridInfo( @@ -955,7 +1021,12 @@ class ESMFAreaWeighted: """ def __init__( - self, mdtol=0, use_src_mask=False, use_tgt_mask=False, tgt_location="face" + self, + mdtol=0, + use_src_mask=False, + use_tgt_mask=False, + tgt_location="face", + esmf_args={}, ): """ Area-weighted scheme for regridding between rectilinear grids. @@ -979,6 +1050,8 @@ def __init__( tgt_location : str or None, default="face" Either "face" or "node". Describes the location for data on the mesh if the target is not a :class:`~iris.cube.Cube`. + esmf_args : dict, optional + A dictionary of arguments to pass to ESMF. """ if not (0 <= mdtol <= 1): @@ -992,10 +1065,20 @@ def __init__( self.use_src_mask = use_src_mask self.use_tgt_mask = use_tgt_mask self.tgt_location = "face" + _check_esmf_args(esmf_args) + self.esmf_args = esmf_args def __repr__(self): """Return a representation of the class.""" - return "ESMFAreaWeighted(mdtol={})".format(self.mdtol) + result = ( + f"ESMFAreaWeighted(" + f" mdtol={self.mdtol}," + f" use_src_mask={self.use_src_mask}," + f" use_tgt_mask={self.use_tgt_mask}," + f" esmf_args={self.esmf_args}," + f")" + ) + return result def regridder( self, @@ -1006,6 +1089,7 @@ def regridder( use_src_mask=None, use_tgt_mask=None, tgt_location="face", + esmf_args=None, ): """ Create regridder to perform regridding from ``src_grid`` to ``tgt_grid``. @@ -1035,6 +1119,8 @@ def regridder( tgt_location : str or None, default="face" Either "face" or "node". Describes the location for data on the mesh if the target is not a :class:`~iris.cube.Cube`. + esmf_args : dict, optional + A dictionary of arguments to pass to ESMF. Returns @@ -1055,6 +1141,9 @@ def regridder( use_src_mask = self.use_src_mask if use_tgt_mask is None: use_tgt_mask = self.use_tgt_mask + if esmf_args is None: + esmf_args = self.esmf_args + if tgt_location is not None and tgt_location != "face": raise ValueError( "For area weighted regridding, target location must be 'face'." @@ -1068,6 +1157,7 @@ def regridder( use_src_mask=use_src_mask, use_tgt_mask=use_tgt_mask, tgt_location="face", + esmf_args=esmf_args, ) @@ -1081,7 +1171,13 @@ class ESMFBilinear: """ def __init__( - self, mdtol=0, use_src_mask=False, use_tgt_mask=False, tgt_location=None + self, + mdtol=0, + use_src_mask=False, + use_tgt_mask=False, + tgt_location=None, + extrapolate_gaps=False, + esmf_args={}, ): """ Area-weighted scheme for regridding between rectilinear grids. @@ -1101,6 +1197,12 @@ def __init__( tgt_location : str or None, default=None Either "face" or "node". Describes the location for data on the mesh if the target is not a :class:`~iris.cube.Cube`. + extrapolate_gaps : bool, default=False + Use a standard set of ESMF arguments for extrapolation which achieves + continuity with bilinear regridding. Useful for situations where gaps + between cells would be masked. + esmf_args : dict, optional + A dictionary of arguments to pass to ESMF. """ if not (0 <= mdtol <= 1): @@ -1110,10 +1212,22 @@ def __init__( self.use_src_mask = use_src_mask self.use_tgt_mask = use_tgt_mask self.tgt_location = tgt_location + if extrapolate_gaps: + esmf_args = STANDAR_BILINEAR_EXTAP_ARGS + _check_esmf_args(esmf_args) + self.esmf_args = esmf_args def __repr__(self): """Return a representation of the class.""" - return "ESMFBilinear(mdtol={})".format(self.mdtol) + result = ( + f"ESMFBilinear(" + f" mdtol={self.mdtol}," + f" use_src_mask={self.use_src_mask}," + f" use_tgt_mask={self.use_tgt_mask}," + f" esmf_args={self.esmf_args}," + f")" + ) + return result def regridder( self, @@ -1122,6 +1236,8 @@ def regridder( use_src_mask=None, use_tgt_mask=None, tgt_location=None, + extrapolate_gaps=False, + esmf_args=None, ): """ Create regridder to perform regridding from ``src_grid`` to ``tgt_grid``. @@ -1142,6 +1258,12 @@ def regridder( tgt_location : str or None, default=None Either "face" or "node". Describes the location for data on the mesh if the target is not a :class:`~iris.cube.Cube`. + extrapolate_gaps : bool, default=False + Use a standard set of ESMF arguments for extrapolation which achieves + continuity with bilinear regridding. Useful for situations where gaps + between cells would be masked. + esmf_args : dict, optional + A dictionary of arguments to pass to ESMF. Returns ------- @@ -1163,6 +1285,10 @@ def regridder( use_tgt_mask = self.use_tgt_mask if tgt_location is None: tgt_location = self.tgt_location + if esmf_args is None: + esmf_args = self.esmf_args + if extrapolate_gaps: + esmf_args = STANDAR_BILINEAR_EXTAP_ARGS return ESMFBilinearRegridder( src_grid, tgt_grid, @@ -1170,6 +1296,7 @@ def regridder( use_src_mask=use_src_mask, use_tgt_mask=use_tgt_mask, tgt_location=tgt_location, + esmf_args=esmf_args, ) @@ -1199,7 +1326,9 @@ class ESMFNearest: the same equivalent space will behave the same. """ - def __init__(self, use_src_mask=False, use_tgt_mask=False, tgt_location=None): + def __init__( + self, use_src_mask=False, use_tgt_mask=False, tgt_location=None, esmf_args={} + ): """ Nearest neighbour scheme for regridding between rectilinear grids. @@ -1214,14 +1343,25 @@ def __init__(self, use_src_mask=False, use_tgt_mask=False, tgt_location=None): tgt_location : str or None, default=None Either "face" or "node". Describes the location for data on the mesh if the target is not a :class:`~iris.cube.Cube`. + esmf_args : dict, optional + A dictionary of arguments to pass to ESMF. """ self.use_src_mask = use_src_mask self.use_tgt_mask = use_tgt_mask self.tgt_location = tgt_location + _check_esmf_args(esmf_args) + self.esmf_args = esmf_args def __repr__(self): """Return a representation of the class.""" - return "ESMFNearest()" + result = ( + f"ESMFNearest(" + f" use_src_mask={self.use_src_mask}," + f" use_tgt_mask={self.use_tgt_mask}," + f" esmf_args={self.esmf_args}," + f")" + ) + return result def regridder( self, @@ -1230,6 +1370,7 @@ def regridder( use_src_mask=None, use_tgt_mask=None, tgt_location=None, + esmf_args=None, ): """ Create regridder to perform regridding from ``src_grid`` to ``tgt_grid``. @@ -1250,6 +1391,8 @@ def regridder( tgt_location : str or None, default=None Either "face" or "node". Describes the location for data on the mesh if the target is not a :class:`~iris.cube.Cube`. + esmf_args : dict, optional + A dictionary of arguments to pass to ESMF. Returns ------- @@ -1271,12 +1414,15 @@ def regridder( use_tgt_mask = self.use_tgt_mask if tgt_location is None: tgt_location = self.tgt_location + if esmf_args is None: + esmf_args = self.esmf_args return ESMFNearestRegridder( src_grid, tgt_grid, use_src_mask=use_src_mask, use_tgt_mask=use_tgt_mask, tgt_location=tgt_location, + esmf_args=esmf_args, ) @@ -1292,6 +1438,7 @@ def __init__( use_src_mask=False, use_tgt_mask=False, tgt_location=None, + esmf_args=None, **kwargs, ): """ @@ -1340,6 +1487,9 @@ def __init__( self.mdtol = mdtol self.method = method + self.esmf_args = esmf_args + kwargs["esmf_args"] = self.esmf_args + self.src_mask = _get_mask(src, use_src_mask) kwargs["src_mask"] = self.src_mask self.tgt_mask = _get_mask(tgt, use_tgt_mask) @@ -1470,6 +1620,7 @@ def __init__( use_src_mask=False, use_tgt_mask=False, tgt_location="face", + esmf_args=None, ): """ Create regridder for conversions between ``src`` and ``tgt``. @@ -1508,6 +1659,8 @@ def __init__( tgt_location : str or None, default="face" Either "face" or "node". Describes the location for data on the mesh if the target is not a :class:`~iris.cube.Cube`. + esmf_args : dict, optional + A dictionary of arguments to pass to ESMF. Raises ------ @@ -1536,6 +1689,7 @@ def __init__( mdtol=mdtol, precomputed_weights=precomputed_weights, tgt_location="face", + esmf_args=esmf_args, **kwargs, ) @@ -1552,6 +1706,8 @@ def __init__( use_src_mask=False, use_tgt_mask=False, tgt_location=None, + extrapolate_gaps=False, + esmf_args=None, ): """ Create regridder for conversions between ``src`` and ``tgt``. @@ -1581,6 +1737,12 @@ def __init__( tgt_location : str or None, default=None Either "face" or "node". Describes the location for data on the mesh if the target is not a :class:`~iris.cube.Cube`. + extrapolate_gaps : bool, default=False + Use a standard set of ESMF arguments for extrapolation which achieves + continuity with bilinear regridding. Useful for situations where gaps + between cells would be masked. + esmf_args : dict, optional + A dictionary of arguments to pass to ESMF. Raises ------ @@ -1588,6 +1750,8 @@ def __init__( If ``use_src_mask`` or ``use_tgt_mask`` are True while the masks on ``src`` or ``tgt`` respectively are not constant over non-horizontal dimensions. """ + if extrapolate_gaps: + esmf_args = STANDAR_BILINEAR_EXTAP_ARGS super().__init__( src, tgt, @@ -1597,6 +1761,7 @@ def __init__( use_src_mask=use_src_mask, use_tgt_mask=use_tgt_mask, tgt_location=tgt_location, + esmf_args=esmf_args, ) @@ -1611,6 +1776,7 @@ def __init__( use_src_mask=False, use_tgt_mask=False, tgt_location=None, + esmf_args=None, ): """ Create regridder for conversions between ``src`` and ``tgt``. @@ -1634,6 +1800,8 @@ def __init__( tgt_location : str or None, default=None Either "face" or "node". Describes the location for data on the mesh if the target is not a :class:`~iris.cube.Cube`. + esmf_args : dict, optional + A dictionary of arguments to pass to ESMF. Raises ------ @@ -1650,4 +1818,5 @@ def __init__( use_src_mask=use_src_mask, use_tgt_mask=use_tgt_mask, tgt_location=tgt_location, + esmf_args=esmf_args, ) diff --git a/esmf_regrid/tests/unit/experimental/io/test_round_tripping.py b/esmf_regrid/tests/unit/experimental/io/test_round_tripping.py index f17fd941..9c1071c3 100644 --- a/esmf_regrid/tests/unit/experimental/io/test_round_tripping.py +++ b/esmf_regrid/tests/unit/experimental/io/test_round_tripping.py @@ -10,6 +10,7 @@ ESMFAreaWeightedRegridder, ESMFBilinear, ESMFNearest, + esmpy, ) from esmf_regrid.experimental.io import load_regridder, save_regridder from esmf_regrid.experimental.unstructured_scheme import ( @@ -456,7 +457,12 @@ def test_generic_regridder(tmp_path, src_type, tgt_type, scheme): elif tgt_type == "mesh": tgt = _gridlike_mesh_cube(n_lons_tgt, n_lats_tgt) - original_rg = scheme().regridder(src, tgt) + esmf_args = { + "line_type": esmpy.LineType.CART, + "large_file": True, + } + + original_rg = scheme().regridder(src, tgt, esmf_args=esmf_args) filename = tmp_path / "regridder.nc" save_regridder(original_rg, filename) loaded_rg = load_regridder(str(filename)) @@ -469,6 +475,7 @@ def test_generic_regridder(tmp_path, src_type, tgt_type, scheme): assert original_rg.src_resolution == loaded_rg.src_resolution assert original_rg.tgt_resolution == loaded_rg.tgt_resolution assert original_rg.mdtol == loaded_rg.mdtol + assert original_rg.esmf_args == loaded_rg.esmf_args @pytest.mark.parametrize( From 2232f898658e14c640ab5bd18eea47652f991ad2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 21 Aug 2024 08:30:38 +0000 Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- esmf_regrid/schemes.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/esmf_regrid/schemes.py b/esmf_regrid/schemes.py index ba60f2cb..4217e262 100644 --- a/esmf_regrid/schemes.py +++ b/esmf_regrid/schemes.py @@ -299,8 +299,8 @@ def _regrid_along_dims(data, regridder, dims, num_out_dims, mdtol): def _check_esmf_args(kwargs): - #TODO: raise proper warning messages - #TODO: check invalid and valid lists are appropriate + # TODO: raise proper warning messages + # TODO: check invalid and valid lists are appropriate if kwargs is not None: if not isinstance(kwargs, dict): raise TypeError("") @@ -314,7 +314,7 @@ def _check_esmf_args(kwargs): "factors", "src_frac_field", "dst_frac_field", - "ignore_degenerate", #TODO: check if this is worth controling + "ignore_degenerate", # TODO: check if this is worth controling ] valid_kwargs = [ "pole method", From 7c788eec47416ccc779ceba90b11c9c7257a189b Mon Sep 17 00:00:00 2001 From: "stephen.worsley" Date: Fri, 23 Aug 2024 09:26:47 +0100 Subject: [PATCH 3/6] fix valid_kwargs, etc --- esmf_regrid/esmf_regridder.py | 7 +++++-- esmf_regrid/experimental/io.py | 6 +----- esmf_regrid/schemes.py | 12 ++++++------ .../unit/experimental/io/test_round_tripping.py | 2 ++ 4 files changed, 14 insertions(+), 13 deletions(-) diff --git a/esmf_regrid/esmf_regridder.py b/esmf_regrid/esmf_regridder.py index af136099..5b9dbaaf 100644 --- a/esmf_regrid/esmf_regridder.py +++ b/esmf_regrid/esmf_regridder.py @@ -19,14 +19,17 @@ def _get_regrid_weights_dict(src_field, tgt_field, regrid_method, esmf_args=None): if esmf_args is None: esmf_args = {} + # Provide default values + if "ignore_degenerate" not in esmf_args: + esmf_args["ignore_degenerate"] = True + if "unmapped_action" not in esmf_args: + esmf_args["unmapped_action"] = esmpy.UnmappedAction.IGNORE # The value, in array form, that ESMF should treat as an affirmative mask. expected_mask = np.array([True]) regridder = esmpy.Regrid( src_field, tgt_field, - ignore_degenerate=True, regrid_method=regrid_method, - unmapped_action=esmpy.UnmappedAction.IGNORE, # Choosing the norm_type DSTAREA allows for mdtol type operations # to be performed using the weights information later on. norm_type=esmpy.NormType.DSTAREA, diff --git a/esmf_regrid/experimental/io.py b/esmf_regrid/experimental/io.py index 26cce879..15244c2d 100644 --- a/esmf_regrid/experimental/io.py +++ b/esmf_regrid/experimental/io.py @@ -50,9 +50,7 @@ SOURCE_RESOLUTION = "src_resolution" TARGET_RESOLUTION = "tgt_resolution" ESMF_ARGS = "esmf_args" -# TODO: check this list is accurate VALID_ESMF_KWARGS = [ - # "regrid_method", "pole_method", "regrid_pole_npoints", "line_type", @@ -61,16 +59,14 @@ "extrap_dist_exponent", "extrap_num_levels", "unmapped_action", - # "ignore_degenerate", + "ignore_degenerate", "large_file", ] -# REGRID_METHOD_DICT = {e.name: e for e in esmpy.RegridMethod} POLE_METHOD_DICT = {e.name: e for e in esmpy.PoleMethod} LINE_TYPE_DICT = {e.name: e for e in esmpy.LineType} EXTRAP_METHOD_DICT = {e.name: e for e in esmpy.ExtrapMethod} UNMAPPED_ACTION_DICT = {e.name: e for e in esmpy.UnmappedAction} ESMF_ENUM_ARGS = { - # "regrid_method": REGRID_METHOD_DICT, "pole_method": POLE_METHOD_DICT, "line_type": LINE_TYPE_DICT, "extrap_method": EXTRAP_METHOD_DICT, diff --git a/esmf_regrid/schemes.py b/esmf_regrid/schemes.py index 4217e262..048c7b24 100644 --- a/esmf_regrid/schemes.py +++ b/esmf_regrid/schemes.py @@ -300,10 +300,9 @@ def _regrid_along_dims(data, regridder, dims, num_out_dims, mdtol): def _check_esmf_args(kwargs): # TODO: raise proper warning messages - # TODO: check invalid and valid lists are appropriate if kwargs is not None: if not isinstance(kwargs, dict): - raise TypeError("") + raise TypeError(f"Expected `esmf_args` to be a dict, got a {type(kwargs)}.") invalid_kwargs = [ "filename", "norm_type", @@ -314,7 +313,6 @@ def _check_esmf_args(kwargs): "factors", "src_frac_field", "dst_frac_field", - "ignore_degenerate", # TODO: check if this is worth controling ] valid_kwargs = [ "pole method", @@ -325,14 +323,16 @@ def _check_esmf_args(kwargs): "extrap_dist_exponent", "extrap_num_levels", "unmapped_action", - # "ignore_degenerate", + "ignore_degenerate", "large_file", ] for kwarg in kwargs.keys(): if kwarg in invalid_kwargs: - raise ValueError("") + msg = f"{kwarg} is not an argument which can be controlled by `esmf_args`." + raise ValueError(msg) if kwarg not in valid_kwargs: - raise ValueError("") + msg = f"{kwarg} is not a valid argument for `esmpy.Regrid`." + raise ValueError(msg) def _map_complete_blocks( diff --git a/esmf_regrid/tests/unit/experimental/io/test_round_tripping.py b/esmf_regrid/tests/unit/experimental/io/test_round_tripping.py index 9c1071c3..c0e78a69 100644 --- a/esmf_regrid/tests/unit/experimental/io/test_round_tripping.py +++ b/esmf_regrid/tests/unit/experimental/io/test_round_tripping.py @@ -458,6 +458,8 @@ def test_generic_regridder(tmp_path, src_type, tgt_type, scheme): tgt = _gridlike_mesh_cube(n_lons_tgt, n_lats_tgt) esmf_args = { + "unmapped_action": esmpy.UnmappedAction.ERROR, + "ignore_degenerate": False, "line_type": esmpy.LineType.CART, "large_file": True, } From 2fb00e3335140d47d5b3f0ee277709f3e3c1ee71 Mon Sep 17 00:00:00 2001 From: "stephen.worsley" Date: Sun, 1 Sep 2024 22:24:19 +0100 Subject: [PATCH 4/6] add tests --- esmf_regrid/esmf_regridder.py | 2 + esmf_regrid/schemes.py | 13 +++-- esmf_regrid/tests/unit/schemes/__init__.py | 52 +++++++++++++++++++ .../unit/schemes/test_ESMFAreaWeighted.py | 6 +++ .../tests/unit/schemes/test_ESMFBilinear.py | 5 ++ .../schemes/test_ESMFBilinearRegridder.py | 34 ++++++++++++ .../tests/unit/schemes/test_ESMFNearest.py | 6 +++ 7 files changed, 111 insertions(+), 7 deletions(-) diff --git a/esmf_regrid/esmf_regridder.py b/esmf_regrid/esmf_regridder.py index 5b9dbaaf..a2641c7a 100644 --- a/esmf_regrid/esmf_regridder.py +++ b/esmf_regrid/esmf_regridder.py @@ -19,6 +19,8 @@ def _get_regrid_weights_dict(src_field, tgt_field, regrid_method, esmf_args=None): if esmf_args is None: esmf_args = {} + else: + esmf_args = esmf_args.copy() # Provide default values if "ignore_degenerate" not in esmf_args: esmf_args["ignore_degenerate"] = True diff --git a/esmf_regrid/schemes.py b/esmf_regrid/schemes.py index 048c7b24..d768038d 100644 --- a/esmf_regrid/schemes.py +++ b/esmf_regrid/schemes.py @@ -299,7 +299,6 @@ def _regrid_along_dims(data, regridder, dims, num_out_dims, mdtol): def _check_esmf_args(kwargs): - # TODO: raise proper warning messages if kwargs is not None: if not isinstance(kwargs, dict): raise TypeError(f"Expected `esmf_args` to be a dict, got a {type(kwargs)}.") @@ -328,10 +327,10 @@ def _check_esmf_args(kwargs): ] for kwarg in kwargs.keys(): if kwarg in invalid_kwargs: - msg = f"{kwarg} is not an argument which can be controlled by `esmf_args`." + msg = f"`esmpy.Regrid` argument `{kwarg}` cannot be controlled by `esmf_args`." raise ValueError(msg) if kwarg not in valid_kwargs: - msg = f"{kwarg} is not a valid argument for `esmpy.Regrid`." + msg = f"`{kwarg}` is not a valid argument for `esmpy.Regrid`." raise ValueError(msg) @@ -1438,7 +1437,7 @@ def __init__( use_src_mask=False, use_tgt_mask=False, tgt_location=None, - esmf_args=None, + esmf_args={}, **kwargs, ): """ @@ -1620,7 +1619,7 @@ def __init__( use_src_mask=False, use_tgt_mask=False, tgt_location="face", - esmf_args=None, + esmf_args={}, ): """ Create regridder for conversions between ``src`` and ``tgt``. @@ -1707,7 +1706,7 @@ def __init__( use_tgt_mask=False, tgt_location=None, extrapolate_gaps=False, - esmf_args=None, + esmf_args={}, ): """ Create regridder for conversions between ``src`` and ``tgt``. @@ -1776,7 +1775,7 @@ def __init__( use_src_mask=False, use_tgt_mask=False, tgt_location=None, - esmf_args=None, + esmf_args={}, ): """ Create regridder for conversions between ``src`` and ``tgt``. diff --git a/esmf_regrid/tests/unit/schemes/__init__.py b/esmf_regrid/tests/unit/schemes/__init__.py index c51deffd..b73a62c4 100644 --- a/esmf_regrid/tests/unit/schemes/__init__.py +++ b/esmf_regrid/tests/unit/schemes/__init__.py @@ -12,6 +12,7 @@ _gridlike_mesh, _gridlike_mesh_cube, ) +from esmf_regrid import esmpy def _test_cube_regrid(scheme, src_type, tgt_type): @@ -252,3 +253,54 @@ def _test_dtype_handling(scheme, src_type, tgt_type, in_dtype): assert result.lazy_data().dtype == expected_dtype assert result.data.dtype == expected_dtype + +def _test_esmf_args(scheme): + """Test regridding scheme handles esmf_args as expected.""" + n_lons_src = 6 + n_lons_tgt = 3 + n_lats_src = 4 + n_lats_tgt = 2 + lon_bounds = (-180, 180) + lat_bounds = (-90, 90) + + src = _grid_cube(n_lons_src, n_lats_src, lon_bounds, lat_bounds, circular=True) + tgt = _grid_cube(n_lons_tgt, n_lats_tgt, lon_bounds, lat_bounds, circular=True) + + valid_esmf_args = { + "unmapped_action": esmpy.UnmappedAction.ERROR, + "ignore_degenerate": False, + "line_type": esmpy.LineType.CART, + "large_file": True, + } + + rg_1 = scheme(esmf_args=valid_esmf_args).regridder(src, tgt) + rg_2 = scheme().regridder(src, tgt, esmf_args=valid_esmf_args) + + assert rg_1.esmf_args == valid_esmf_args + assert rg_2.esmf_args == valid_esmf_args + + invalid_esmf_args_duplicate = { + "regrid_method": None + } + invalid_esmf_args_incorrect = { + "invalid_arg": None + } + invalid_esmf_args_type = "invalid_arg" + + match_duplicate = "cannot be controlled by `esmf_args`" + with pytest.raises(ValueError, match=match_duplicate): + rg = scheme(esmf_args=invalid_esmf_args_duplicate) + with pytest.raises(ValueError, match=match_duplicate): + rg = scheme().regridder(src, tgt, esmf_args=invalid_esmf_args_duplicate) + + match_incorrect = "is not a valid argument for" + with pytest.raises(ValueError, match=match_incorrect): + rg = scheme(esmf_args=invalid_esmf_args_incorrect) + with pytest.raises(ValueError, match=match_incorrect): + rg = scheme().regridder(src, tgt, esmf_args=invalid_esmf_args_incorrect) + + match_type = "Expected `esmf_args` to be a dict, got a " + with pytest.raises(TypeError, match=match_type): + rg = scheme(esmf_args=invalid_esmf_args_type) + with pytest.raises(TypeError, match=match_type): + rg = scheme().regridder(src, tgt, esmf_args=invalid_esmf_args_type) diff --git a/esmf_regrid/tests/unit/schemes/test_ESMFAreaWeighted.py b/esmf_regrid/tests/unit/schemes/test_ESMFAreaWeighted.py index 50af91b3..f48ceae9 100644 --- a/esmf_regrid/tests/unit/schemes/test_ESMFAreaWeighted.py +++ b/esmf_regrid/tests/unit/schemes/test_ESMFAreaWeighted.py @@ -6,6 +6,7 @@ from esmf_regrid.tests.unit.schemes.__init__ import ( _test_cube_regrid, _test_dtype_handling, + _test_esmf_args, _test_invalid_mdtol, _test_mask_from_init, _test_mask_from_regridder, @@ -81,3 +82,8 @@ def test_dtype_handling(src_tgt_types, in_dtype): """Test regridding scheme handles dtype as expected.""" src_type, tgt_type = src_tgt_types _test_dtype_handling(ESMFAreaWeighted, src_type, tgt_type, in_dtype) + + +def test_esmf_args(): + """Test regridding scheme handles esmf_args as expected.""" + _test_esmf_args(ESMFAreaWeighted) diff --git a/esmf_regrid/tests/unit/schemes/test_ESMFBilinear.py b/esmf_regrid/tests/unit/schemes/test_ESMFBilinear.py index f6a9e30f..b05221cd 100644 --- a/esmf_regrid/tests/unit/schemes/test_ESMFBilinear.py +++ b/esmf_regrid/tests/unit/schemes/test_ESMFBilinear.py @@ -6,6 +6,7 @@ from esmf_regrid.tests.unit.schemes.__init__ import ( _test_cube_regrid, _test_dtype_handling, + _test_esmf_args, _test_invalid_mdtol, _test_mask_from_init, _test_mask_from_regridder, @@ -70,3 +71,7 @@ def test_dtype_handling(src_tgt_types, in_dtype): """Test regridding scheme handles dtype as expected.""" src_type, tgt_type = src_tgt_types _test_dtype_handling(ESMFBilinear, src_type, tgt_type, in_dtype) + +def test_esmf_args(): + """Test regridding scheme handles esmf_args as expected.""" + _test_esmf_args(ESMFBilinear) diff --git a/esmf_regrid/tests/unit/schemes/test_ESMFBilinearRegridder.py b/esmf_regrid/tests/unit/schemes/test_ESMFBilinearRegridder.py index da60f416..3cb494a2 100644 --- a/esmf_regrid/tests/unit/schemes/test_ESMFBilinearRegridder.py +++ b/esmf_regrid/tests/unit/schemes/test_ESMFBilinearRegridder.py @@ -9,6 +9,7 @@ _curvilinear_cube, _grid_cube, ) +from esmf_regrid import esmpy def test_dim_switching(): @@ -329,3 +330,36 @@ def test_regrid_data(): ) result = rg(src) np.testing.assert_allclose(expected_data, result.data) + +def test_extrapolate_gaps(): + n_lons = 6 + n_lats = 5 + src_lon_bounds = (-140, 180) + tgt_lon_bounds = (-180, 180) + src_lat_bounds = (-80, 90) + tgt_lat_bounds = (-90, 90) + src = _grid_cube(n_lons, n_lats, src_lon_bounds, src_lat_bounds, circular=False) + src.data = np.arange(n_lons * n_lats).reshape(n_lats, n_lons) + tgt = _grid_cube(n_lons, n_lats, tgt_lon_bounds, tgt_lat_bounds, circular=True) + + extrapolate_regridder = ESMFBilinearRegridder( + src, tgt, extrapolate_gaps=True + ) + normal_regridder = ESMFBilinearRegridder( + src, tgt, extrapolate_gaps=False + ) + + extrapolate_result = extrapolate_regridder(src) + normal_result = normal_regridder(src) + + assert not np.ma.is_masked(extrapolate_result.data) + assert np.ma.is_masked(normal_result.data) + + expected_args = { + "extrap_method": esmpy.ExtrapMethod.NEAREST_IDAVG, + "extrap_num_src_pnts": 2, + "extrap_dist_exponent": 1, +} + assert extrapolate_regridder.esmf_args == expected_args + assert normal_regridder.esmf_args == {} + diff --git a/esmf_regrid/tests/unit/schemes/test_ESMFNearest.py b/esmf_regrid/tests/unit/schemes/test_ESMFNearest.py index 6f146e6c..db90d830 100644 --- a/esmf_regrid/tests/unit/schemes/test_ESMFNearest.py +++ b/esmf_regrid/tests/unit/schemes/test_ESMFNearest.py @@ -7,6 +7,7 @@ from esmf_regrid.schemes import ESMFNearest from esmf_regrid.tests.unit.schemes.__init__ import ( _test_dtype_handling, + _test_esmf_args, _test_mask_from_init, _test_mask_from_regridder, _test_non_degree_crs, @@ -113,3 +114,8 @@ def test_dtype_handling(src_tgt_types, in_dtype): """Test regridding scheme handles dtype as expected.""" src_type, tgt_type = src_tgt_types _test_dtype_handling(ESMFNearest, src_type, tgt_type, in_dtype) + + +def test_esmf_args(): + """Test regridding scheme handles esmf_args as expected.""" + _test_esmf_args(ESMFNearest) From 4b82a64297b5957680bef5acba05447eb90ec05a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 1 Sep 2024 21:24:38 +0000 Subject: [PATCH 5/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- esmf_regrid/tests/unit/schemes/__init__.py | 9 +++------ .../tests/unit/schemes/test_ESMFBilinear.py | 1 + .../unit/schemes/test_ESMFBilinearRegridder.py | 18 +++++++----------- 3 files changed, 11 insertions(+), 17 deletions(-) diff --git a/esmf_regrid/tests/unit/schemes/__init__.py b/esmf_regrid/tests/unit/schemes/__init__.py index b73a62c4..a4ba3b34 100644 --- a/esmf_regrid/tests/unit/schemes/__init__.py +++ b/esmf_regrid/tests/unit/schemes/__init__.py @@ -254,6 +254,7 @@ def _test_dtype_handling(scheme, src_type, tgt_type, in_dtype): assert result.lazy_data().dtype == expected_dtype assert result.data.dtype == expected_dtype + def _test_esmf_args(scheme): """Test regridding scheme handles esmf_args as expected.""" n_lons_src = 6 @@ -279,12 +280,8 @@ def _test_esmf_args(scheme): assert rg_1.esmf_args == valid_esmf_args assert rg_2.esmf_args == valid_esmf_args - invalid_esmf_args_duplicate = { - "regrid_method": None - } - invalid_esmf_args_incorrect = { - "invalid_arg": None - } + invalid_esmf_args_duplicate = {"regrid_method": None} + invalid_esmf_args_incorrect = {"invalid_arg": None} invalid_esmf_args_type = "invalid_arg" match_duplicate = "cannot be controlled by `esmf_args`" diff --git a/esmf_regrid/tests/unit/schemes/test_ESMFBilinear.py b/esmf_regrid/tests/unit/schemes/test_ESMFBilinear.py index b05221cd..8b4a5a1f 100644 --- a/esmf_regrid/tests/unit/schemes/test_ESMFBilinear.py +++ b/esmf_regrid/tests/unit/schemes/test_ESMFBilinear.py @@ -72,6 +72,7 @@ def test_dtype_handling(src_tgt_types, in_dtype): src_type, tgt_type = src_tgt_types _test_dtype_handling(ESMFBilinear, src_type, tgt_type, in_dtype) + def test_esmf_args(): """Test regridding scheme handles esmf_args as expected.""" _test_esmf_args(ESMFBilinear) diff --git a/esmf_regrid/tests/unit/schemes/test_ESMFBilinearRegridder.py b/esmf_regrid/tests/unit/schemes/test_ESMFBilinearRegridder.py index 3cb494a2..2d8e4b96 100644 --- a/esmf_regrid/tests/unit/schemes/test_ESMFBilinearRegridder.py +++ b/esmf_regrid/tests/unit/schemes/test_ESMFBilinearRegridder.py @@ -331,6 +331,7 @@ def test_regrid_data(): result = rg(src) np.testing.assert_allclose(expected_data, result.data) + def test_extrapolate_gaps(): n_lons = 6 n_lats = 5 @@ -342,12 +343,8 @@ def test_extrapolate_gaps(): src.data = np.arange(n_lons * n_lats).reshape(n_lats, n_lons) tgt = _grid_cube(n_lons, n_lats, tgt_lon_bounds, tgt_lat_bounds, circular=True) - extrapolate_regridder = ESMFBilinearRegridder( - src, tgt, extrapolate_gaps=True - ) - normal_regridder = ESMFBilinearRegridder( - src, tgt, extrapolate_gaps=False - ) + extrapolate_regridder = ESMFBilinearRegridder(src, tgt, extrapolate_gaps=True) + normal_regridder = ESMFBilinearRegridder(src, tgt, extrapolate_gaps=False) extrapolate_result = extrapolate_regridder(src) normal_result = normal_regridder(src) @@ -356,10 +353,9 @@ def test_extrapolate_gaps(): assert np.ma.is_masked(normal_result.data) expected_args = { - "extrap_method": esmpy.ExtrapMethod.NEAREST_IDAVG, - "extrap_num_src_pnts": 2, - "extrap_dist_exponent": 1, -} + "extrap_method": esmpy.ExtrapMethod.NEAREST_IDAVG, + "extrap_num_src_pnts": 2, + "extrap_dist_exponent": 1, + } assert extrapolate_regridder.esmf_args == expected_args assert normal_regridder.esmf_args == {} - From 6c23d6586ac3dfe9dc605482e8a8a117d9f1d959 Mon Sep 17 00:00:00 2001 From: "stephen.worsley" Date: Tue, 3 Sep 2024 23:48:42 +0100 Subject: [PATCH 6/6] flake8 --- esmf_regrid/experimental/io.py | 4 +--- esmf_regrid/tests/unit/schemes/__init__.py | 14 +++++++------- .../unit/schemes/test_ESMFBilinearRegridder.py | 3 ++- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/esmf_regrid/experimental/io.py b/esmf_regrid/experimental/io.py index 15244c2d..0163cccb 100644 --- a/esmf_regrid/experimental/io.py +++ b/esmf_regrid/experimental/io.py @@ -9,8 +9,7 @@ import scipy.sparse import esmf_regrid -from esmf_regrid import _load_context -from esmf_regrid import check_method, Constants +from esmf_regrid import _load_context, check_method, Constants, esmpy from esmf_regrid.experimental.unstructured_scheme import ( GridToMeshESMFRegridder, MeshToGridESMFRegridder, @@ -22,7 +21,6 @@ GridRecord, MeshRecord, ) -from esmf_regrid import esmpy SUPPORTED_REGRIDDERS = [ diff --git a/esmf_regrid/tests/unit/schemes/__init__.py b/esmf_regrid/tests/unit/schemes/__init__.py index a4ba3b34..92ab695c 100644 --- a/esmf_regrid/tests/unit/schemes/__init__.py +++ b/esmf_regrid/tests/unit/schemes/__init__.py @@ -6,13 +6,13 @@ from numpy import ma import pytest +from esmf_regrid import esmpy from esmf_regrid.schemes import ESMFAreaWeighted, ESMFBilinear, ESMFNearest from esmf_regrid.tests.unit.schemes.test__cube_to_GridInfo import _grid_cube from esmf_regrid.tests.unit.schemes.test__mesh_to_MeshInfo import ( _gridlike_mesh, _gridlike_mesh_cube, ) -from esmf_regrid import esmpy def _test_cube_regrid(scheme, src_type, tgt_type): @@ -286,18 +286,18 @@ def _test_esmf_args(scheme): match_duplicate = "cannot be controlled by `esmf_args`" with pytest.raises(ValueError, match=match_duplicate): - rg = scheme(esmf_args=invalid_esmf_args_duplicate) + _ = scheme(esmf_args=invalid_esmf_args_duplicate) with pytest.raises(ValueError, match=match_duplicate): - rg = scheme().regridder(src, tgt, esmf_args=invalid_esmf_args_duplicate) + _ = scheme().regridder(src, tgt, esmf_args=invalid_esmf_args_duplicate) match_incorrect = "is not a valid argument for" with pytest.raises(ValueError, match=match_incorrect): - rg = scheme(esmf_args=invalid_esmf_args_incorrect) + _ = scheme(esmf_args=invalid_esmf_args_incorrect) with pytest.raises(ValueError, match=match_incorrect): - rg = scheme().regridder(src, tgt, esmf_args=invalid_esmf_args_incorrect) + _ = scheme().regridder(src, tgt, esmf_args=invalid_esmf_args_incorrect) match_type = "Expected `esmf_args` to be a dict, got a " with pytest.raises(TypeError, match=match_type): - rg = scheme(esmf_args=invalid_esmf_args_type) + _ = scheme(esmf_args=invalid_esmf_args_type) with pytest.raises(TypeError, match=match_type): - rg = scheme().regridder(src, tgt, esmf_args=invalid_esmf_args_type) + _ = scheme().regridder(src, tgt, esmf_args=invalid_esmf_args_type) diff --git a/esmf_regrid/tests/unit/schemes/test_ESMFBilinearRegridder.py b/esmf_regrid/tests/unit/schemes/test_ESMFBilinearRegridder.py index 2d8e4b96..0a39172d 100644 --- a/esmf_regrid/tests/unit/schemes/test_ESMFBilinearRegridder.py +++ b/esmf_regrid/tests/unit/schemes/test_ESMFBilinearRegridder.py @@ -4,12 +4,12 @@ import numpy as np import pytest +from esmf_regrid import esmpy from esmf_regrid.schemes import ESMFBilinearRegridder from esmf_regrid.tests.unit.schemes.test__cube_to_GridInfo import ( _curvilinear_cube, _grid_cube, ) -from esmf_regrid import esmpy def test_dim_switching(): @@ -333,6 +333,7 @@ def test_regrid_data(): def test_extrapolate_gaps(): + """Test that `extrapolate_gaps` argument works as expected.""" n_lons = 6 n_lats = 5 src_lon_bounds = (-140, 180)