Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
stephenworsley committed Sep 1, 2024
1 parent 7c788ee commit 2fb00e3
Show file tree
Hide file tree
Showing 7 changed files with 111 additions and 7 deletions.
2 changes: 2 additions & 0 deletions esmf_regrid/esmf_regridder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
def _get_regrid_weights_dict(src_field, tgt_field, regrid_method, esmf_args=None):
if esmf_args is None:
esmf_args = {}
else:
esmf_args = esmf_args.copy()
# Provide default values
if "ignore_degenerate" not in esmf_args:
esmf_args["ignore_degenerate"] = True
Expand Down
13 changes: 6 additions & 7 deletions esmf_regrid/schemes.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,6 @@ def _regrid_along_dims(data, regridder, dims, num_out_dims, mdtol):


def _check_esmf_args(kwargs):
# TODO: raise proper warning messages
if kwargs is not None:
if not isinstance(kwargs, dict):
raise TypeError(f"Expected `esmf_args` to be a dict, got a {type(kwargs)}.")
Expand Down Expand Up @@ -328,10 +327,10 @@ def _check_esmf_args(kwargs):
]
for kwarg in kwargs.keys():
if kwarg in invalid_kwargs:
msg = f"{kwarg} is not an argument which can be controlled by `esmf_args`."
msg = f"`esmpy.Regrid` argument `{kwarg}` cannot be controlled by `esmf_args`."
raise ValueError(msg)
if kwarg not in valid_kwargs:
msg = f"{kwarg} is not a valid argument for `esmpy.Regrid`."
msg = f"`{kwarg}` is not a valid argument for `esmpy.Regrid`."
raise ValueError(msg)


Expand Down Expand Up @@ -1438,7 +1437,7 @@ def __init__(
use_src_mask=False,
use_tgt_mask=False,
tgt_location=None,
esmf_args=None,
esmf_args={},
**kwargs,
):
"""
Expand Down Expand Up @@ -1620,7 +1619,7 @@ def __init__(
use_src_mask=False,
use_tgt_mask=False,
tgt_location="face",
esmf_args=None,
esmf_args={},
):
"""
Create regridder for conversions between ``src`` and ``tgt``.
Expand Down Expand Up @@ -1707,7 +1706,7 @@ def __init__(
use_tgt_mask=False,
tgt_location=None,
extrapolate_gaps=False,
esmf_args=None,
esmf_args={},
):
"""
Create regridder for conversions between ``src`` and ``tgt``.
Expand Down Expand Up @@ -1776,7 +1775,7 @@ def __init__(
use_src_mask=False,
use_tgt_mask=False,
tgt_location=None,
esmf_args=None,
esmf_args={},
):
"""
Create regridder for conversions between ``src`` and ``tgt``.
Expand Down
52 changes: 52 additions & 0 deletions esmf_regrid/tests/unit/schemes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
_gridlike_mesh,
_gridlike_mesh_cube,
)
from esmf_regrid import esmpy


def _test_cube_regrid(scheme, src_type, tgt_type):
Expand Down Expand Up @@ -252,3 +253,54 @@ def _test_dtype_handling(scheme, src_type, tgt_type, in_dtype):

assert result.lazy_data().dtype == expected_dtype
assert result.data.dtype == expected_dtype

def _test_esmf_args(scheme):
"""Test regridding scheme handles esmf_args as expected."""
n_lons_src = 6
n_lons_tgt = 3
n_lats_src = 4
n_lats_tgt = 2
lon_bounds = (-180, 180)
lat_bounds = (-90, 90)

src = _grid_cube(n_lons_src, n_lats_src, lon_bounds, lat_bounds, circular=True)
tgt = _grid_cube(n_lons_tgt, n_lats_tgt, lon_bounds, lat_bounds, circular=True)

valid_esmf_args = {
"unmapped_action": esmpy.UnmappedAction.ERROR,
"ignore_degenerate": False,
"line_type": esmpy.LineType.CART,
"large_file": True,
}

rg_1 = scheme(esmf_args=valid_esmf_args).regridder(src, tgt)
rg_2 = scheme().regridder(src, tgt, esmf_args=valid_esmf_args)

assert rg_1.esmf_args == valid_esmf_args
assert rg_2.esmf_args == valid_esmf_args

invalid_esmf_args_duplicate = {
"regrid_method": None
}
invalid_esmf_args_incorrect = {
"invalid_arg": None
}
invalid_esmf_args_type = "invalid_arg"

match_duplicate = "cannot be controlled by `esmf_args`"
with pytest.raises(ValueError, match=match_duplicate):
rg = scheme(esmf_args=invalid_esmf_args_duplicate)
with pytest.raises(ValueError, match=match_duplicate):
rg = scheme().regridder(src, tgt, esmf_args=invalid_esmf_args_duplicate)

match_incorrect = "is not a valid argument for"
with pytest.raises(ValueError, match=match_incorrect):
rg = scheme(esmf_args=invalid_esmf_args_incorrect)
with pytest.raises(ValueError, match=match_incorrect):
rg = scheme().regridder(src, tgt, esmf_args=invalid_esmf_args_incorrect)

match_type = "Expected `esmf_args` to be a dict, got a "
with pytest.raises(TypeError, match=match_type):
rg = scheme(esmf_args=invalid_esmf_args_type)
with pytest.raises(TypeError, match=match_type):
rg = scheme().regridder(src, tgt, esmf_args=invalid_esmf_args_type)
6 changes: 6 additions & 0 deletions esmf_regrid/tests/unit/schemes/test_ESMFAreaWeighted.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from esmf_regrid.tests.unit.schemes.__init__ import (
_test_cube_regrid,
_test_dtype_handling,
_test_esmf_args,
_test_invalid_mdtol,
_test_mask_from_init,
_test_mask_from_regridder,
Expand Down Expand Up @@ -81,3 +82,8 @@ def test_dtype_handling(src_tgt_types, in_dtype):
"""Test regridding scheme handles dtype as expected."""
src_type, tgt_type = src_tgt_types
_test_dtype_handling(ESMFAreaWeighted, src_type, tgt_type, in_dtype)


def test_esmf_args():
"""Test regridding scheme handles esmf_args as expected."""
_test_esmf_args(ESMFAreaWeighted)
5 changes: 5 additions & 0 deletions esmf_regrid/tests/unit/schemes/test_ESMFBilinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from esmf_regrid.tests.unit.schemes.__init__ import (
_test_cube_regrid,
_test_dtype_handling,
_test_esmf_args,
_test_invalid_mdtol,
_test_mask_from_init,
_test_mask_from_regridder,
Expand Down Expand Up @@ -70,3 +71,7 @@ def test_dtype_handling(src_tgt_types, in_dtype):
"""Test regridding scheme handles dtype as expected."""
src_type, tgt_type = src_tgt_types
_test_dtype_handling(ESMFBilinear, src_type, tgt_type, in_dtype)

def test_esmf_args():
"""Test regridding scheme handles esmf_args as expected."""
_test_esmf_args(ESMFBilinear)
34 changes: 34 additions & 0 deletions esmf_regrid/tests/unit/schemes/test_ESMFBilinearRegridder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
_curvilinear_cube,
_grid_cube,
)
from esmf_regrid import esmpy


def test_dim_switching():
Expand Down Expand Up @@ -329,3 +330,36 @@ def test_regrid_data():
)
result = rg(src)
np.testing.assert_allclose(expected_data, result.data)

def test_extrapolate_gaps():
n_lons = 6
n_lats = 5
src_lon_bounds = (-140, 180)
tgt_lon_bounds = (-180, 180)
src_lat_bounds = (-80, 90)
tgt_lat_bounds = (-90, 90)
src = _grid_cube(n_lons, n_lats, src_lon_bounds, src_lat_bounds, circular=False)
src.data = np.arange(n_lons * n_lats).reshape(n_lats, n_lons)
tgt = _grid_cube(n_lons, n_lats, tgt_lon_bounds, tgt_lat_bounds, circular=True)

extrapolate_regridder = ESMFBilinearRegridder(
src, tgt, extrapolate_gaps=True
)
normal_regridder = ESMFBilinearRegridder(
src, tgt, extrapolate_gaps=False
)

extrapolate_result = extrapolate_regridder(src)
normal_result = normal_regridder(src)

assert not np.ma.is_masked(extrapolate_result.data)
assert np.ma.is_masked(normal_result.data)

expected_args = {
"extrap_method": esmpy.ExtrapMethod.NEAREST_IDAVG,
"extrap_num_src_pnts": 2,
"extrap_dist_exponent": 1,
}
assert extrapolate_regridder.esmf_args == expected_args
assert normal_regridder.esmf_args == {}

6 changes: 6 additions & 0 deletions esmf_regrid/tests/unit/schemes/test_ESMFNearest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from esmf_regrid.schemes import ESMFNearest
from esmf_regrid.tests.unit.schemes.__init__ import (
_test_dtype_handling,
_test_esmf_args,
_test_mask_from_init,
_test_mask_from_regridder,
_test_non_degree_crs,
Expand Down Expand Up @@ -113,3 +114,8 @@ def test_dtype_handling(src_tgt_types, in_dtype):
"""Test regridding scheme handles dtype as expected."""
src_type, tgt_type = src_tgt_types
_test_dtype_handling(ESMFNearest, src_type, tgt_type, in_dtype)


def test_esmf_args():
"""Test regridding scheme handles esmf_args as expected."""
_test_esmf_args(ESMFNearest)

0 comments on commit 2fb00e3

Please sign in to comment.