diff --git a/esmf_regrid/esmf_regridder.py b/esmf_regrid/esmf_regridder.py index 5b9dbaaf..a2641c7a 100644 --- a/esmf_regrid/esmf_regridder.py +++ b/esmf_regrid/esmf_regridder.py @@ -19,6 +19,8 @@ def _get_regrid_weights_dict(src_field, tgt_field, regrid_method, esmf_args=None): if esmf_args is None: esmf_args = {} + else: + esmf_args = esmf_args.copy() # Provide default values if "ignore_degenerate" not in esmf_args: esmf_args["ignore_degenerate"] = True diff --git a/esmf_regrid/schemes.py b/esmf_regrid/schemes.py index 048c7b24..d768038d 100644 --- a/esmf_regrid/schemes.py +++ b/esmf_regrid/schemes.py @@ -299,7 +299,6 @@ def _regrid_along_dims(data, regridder, dims, num_out_dims, mdtol): def _check_esmf_args(kwargs): - # TODO: raise proper warning messages if kwargs is not None: if not isinstance(kwargs, dict): raise TypeError(f"Expected `esmf_args` to be a dict, got a {type(kwargs)}.") @@ -328,10 +327,10 @@ def _check_esmf_args(kwargs): ] for kwarg in kwargs.keys(): if kwarg in invalid_kwargs: - msg = f"{kwarg} is not an argument which can be controlled by `esmf_args`." + msg = f"`esmpy.Regrid` argument `{kwarg}` cannot be controlled by `esmf_args`." raise ValueError(msg) if kwarg not in valid_kwargs: - msg = f"{kwarg} is not a valid argument for `esmpy.Regrid`." + msg = f"`{kwarg}` is not a valid argument for `esmpy.Regrid`." raise ValueError(msg) @@ -1438,7 +1437,7 @@ def __init__( use_src_mask=False, use_tgt_mask=False, tgt_location=None, - esmf_args=None, + esmf_args={}, **kwargs, ): """ @@ -1620,7 +1619,7 @@ def __init__( use_src_mask=False, use_tgt_mask=False, tgt_location="face", - esmf_args=None, + esmf_args={}, ): """ Create regridder for conversions between ``src`` and ``tgt``. @@ -1707,7 +1706,7 @@ def __init__( use_tgt_mask=False, tgt_location=None, extrapolate_gaps=False, - esmf_args=None, + esmf_args={}, ): """ Create regridder for conversions between ``src`` and ``tgt``. @@ -1776,7 +1775,7 @@ def __init__( use_src_mask=False, use_tgt_mask=False, tgt_location=None, - esmf_args=None, + esmf_args={}, ): """ Create regridder for conversions between ``src`` and ``tgt``. diff --git a/esmf_regrid/tests/unit/schemes/__init__.py b/esmf_regrid/tests/unit/schemes/__init__.py index c51deffd..b73a62c4 100644 --- a/esmf_regrid/tests/unit/schemes/__init__.py +++ b/esmf_regrid/tests/unit/schemes/__init__.py @@ -12,6 +12,7 @@ _gridlike_mesh, _gridlike_mesh_cube, ) +from esmf_regrid import esmpy def _test_cube_regrid(scheme, src_type, tgt_type): @@ -252,3 +253,54 @@ def _test_dtype_handling(scheme, src_type, tgt_type, in_dtype): assert result.lazy_data().dtype == expected_dtype assert result.data.dtype == expected_dtype + +def _test_esmf_args(scheme): + """Test regridding scheme handles esmf_args as expected.""" + n_lons_src = 6 + n_lons_tgt = 3 + n_lats_src = 4 + n_lats_tgt = 2 + lon_bounds = (-180, 180) + lat_bounds = (-90, 90) + + src = _grid_cube(n_lons_src, n_lats_src, lon_bounds, lat_bounds, circular=True) + tgt = _grid_cube(n_lons_tgt, n_lats_tgt, lon_bounds, lat_bounds, circular=True) + + valid_esmf_args = { + "unmapped_action": esmpy.UnmappedAction.ERROR, + "ignore_degenerate": False, + "line_type": esmpy.LineType.CART, + "large_file": True, + } + + rg_1 = scheme(esmf_args=valid_esmf_args).regridder(src, tgt) + rg_2 = scheme().regridder(src, tgt, esmf_args=valid_esmf_args) + + assert rg_1.esmf_args == valid_esmf_args + assert rg_2.esmf_args == valid_esmf_args + + invalid_esmf_args_duplicate = { + "regrid_method": None + } + invalid_esmf_args_incorrect = { + "invalid_arg": None + } + invalid_esmf_args_type = "invalid_arg" + + match_duplicate = "cannot be controlled by `esmf_args`" + with pytest.raises(ValueError, match=match_duplicate): + rg = scheme(esmf_args=invalid_esmf_args_duplicate) + with pytest.raises(ValueError, match=match_duplicate): + rg = scheme().regridder(src, tgt, esmf_args=invalid_esmf_args_duplicate) + + match_incorrect = "is not a valid argument for" + with pytest.raises(ValueError, match=match_incorrect): + rg = scheme(esmf_args=invalid_esmf_args_incorrect) + with pytest.raises(ValueError, match=match_incorrect): + rg = scheme().regridder(src, tgt, esmf_args=invalid_esmf_args_incorrect) + + match_type = "Expected `esmf_args` to be a dict, got a " + with pytest.raises(TypeError, match=match_type): + rg = scheme(esmf_args=invalid_esmf_args_type) + with pytest.raises(TypeError, match=match_type): + rg = scheme().regridder(src, tgt, esmf_args=invalid_esmf_args_type) diff --git a/esmf_regrid/tests/unit/schemes/test_ESMFAreaWeighted.py b/esmf_regrid/tests/unit/schemes/test_ESMFAreaWeighted.py index 50af91b3..f48ceae9 100644 --- a/esmf_regrid/tests/unit/schemes/test_ESMFAreaWeighted.py +++ b/esmf_regrid/tests/unit/schemes/test_ESMFAreaWeighted.py @@ -6,6 +6,7 @@ from esmf_regrid.tests.unit.schemes.__init__ import ( _test_cube_regrid, _test_dtype_handling, + _test_esmf_args, _test_invalid_mdtol, _test_mask_from_init, _test_mask_from_regridder, @@ -81,3 +82,8 @@ def test_dtype_handling(src_tgt_types, in_dtype): """Test regridding scheme handles dtype as expected.""" src_type, tgt_type = src_tgt_types _test_dtype_handling(ESMFAreaWeighted, src_type, tgt_type, in_dtype) + + +def test_esmf_args(): + """Test regridding scheme handles esmf_args as expected.""" + _test_esmf_args(ESMFAreaWeighted) diff --git a/esmf_regrid/tests/unit/schemes/test_ESMFBilinear.py b/esmf_regrid/tests/unit/schemes/test_ESMFBilinear.py index f6a9e30f..b05221cd 100644 --- a/esmf_regrid/tests/unit/schemes/test_ESMFBilinear.py +++ b/esmf_regrid/tests/unit/schemes/test_ESMFBilinear.py @@ -6,6 +6,7 @@ from esmf_regrid.tests.unit.schemes.__init__ import ( _test_cube_regrid, _test_dtype_handling, + _test_esmf_args, _test_invalid_mdtol, _test_mask_from_init, _test_mask_from_regridder, @@ -70,3 +71,7 @@ def test_dtype_handling(src_tgt_types, in_dtype): """Test regridding scheme handles dtype as expected.""" src_type, tgt_type = src_tgt_types _test_dtype_handling(ESMFBilinear, src_type, tgt_type, in_dtype) + +def test_esmf_args(): + """Test regridding scheme handles esmf_args as expected.""" + _test_esmf_args(ESMFBilinear) diff --git a/esmf_regrid/tests/unit/schemes/test_ESMFBilinearRegridder.py b/esmf_regrid/tests/unit/schemes/test_ESMFBilinearRegridder.py index da60f416..3cb494a2 100644 --- a/esmf_regrid/tests/unit/schemes/test_ESMFBilinearRegridder.py +++ b/esmf_regrid/tests/unit/schemes/test_ESMFBilinearRegridder.py @@ -9,6 +9,7 @@ _curvilinear_cube, _grid_cube, ) +from esmf_regrid import esmpy def test_dim_switching(): @@ -329,3 +330,36 @@ def test_regrid_data(): ) result = rg(src) np.testing.assert_allclose(expected_data, result.data) + +def test_extrapolate_gaps(): + n_lons = 6 + n_lats = 5 + src_lon_bounds = (-140, 180) + tgt_lon_bounds = (-180, 180) + src_lat_bounds = (-80, 90) + tgt_lat_bounds = (-90, 90) + src = _grid_cube(n_lons, n_lats, src_lon_bounds, src_lat_bounds, circular=False) + src.data = np.arange(n_lons * n_lats).reshape(n_lats, n_lons) + tgt = _grid_cube(n_lons, n_lats, tgt_lon_bounds, tgt_lat_bounds, circular=True) + + extrapolate_regridder = ESMFBilinearRegridder( + src, tgt, extrapolate_gaps=True + ) + normal_regridder = ESMFBilinearRegridder( + src, tgt, extrapolate_gaps=False + ) + + extrapolate_result = extrapolate_regridder(src) + normal_result = normal_regridder(src) + + assert not np.ma.is_masked(extrapolate_result.data) + assert np.ma.is_masked(normal_result.data) + + expected_args = { + "extrap_method": esmpy.ExtrapMethod.NEAREST_IDAVG, + "extrap_num_src_pnts": 2, + "extrap_dist_exponent": 1, +} + assert extrapolate_regridder.esmf_args == expected_args + assert normal_regridder.esmf_args == {} + diff --git a/esmf_regrid/tests/unit/schemes/test_ESMFNearest.py b/esmf_regrid/tests/unit/schemes/test_ESMFNearest.py index 6f146e6c..db90d830 100644 --- a/esmf_regrid/tests/unit/schemes/test_ESMFNearest.py +++ b/esmf_regrid/tests/unit/schemes/test_ESMFNearest.py @@ -7,6 +7,7 @@ from esmf_regrid.schemes import ESMFNearest from esmf_regrid.tests.unit.schemes.__init__ import ( _test_dtype_handling, + _test_esmf_args, _test_mask_from_init, _test_mask_from_regridder, _test_non_degree_crs, @@ -113,3 +114,8 @@ def test_dtype_handling(src_tgt_types, in_dtype): """Test regridding scheme handles dtype as expected.""" src_type, tgt_type = src_tgt_types _test_dtype_handling(ESMFNearest, src_type, tgt_type, in_dtype) + + +def test_esmf_args(): + """Test regridding scheme handles esmf_args as expected.""" + _test_esmf_args(ESMFNearest)