Skip to content

Commit

Permalink
clarify and test tgt_location defaults
Browse files Browse the repository at this point in the history
  • Loading branch information
stephenworsley committed Aug 21, 2023
1 parent 20d940b commit e860743
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 9 deletions.
2 changes: 1 addition & 1 deletion esmf_regrid/experimental/unstructured_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def __init__(
a boolean value. If True, this array is taken from the mask on the data
in ``tgt``. If False, no mask will be taken and all points
will be used in weights calculation.
tgt_location : str or None, default="face"
tgt_location : str or None, default=None
Either "face" or "node". Describes the location for data on the mesh
if the target is not a :class:`~iris.cube.Cube`.
Expand Down
22 changes: 14 additions & 8 deletions esmf_regrid/schemes.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,7 +849,7 @@ class ESMFAreaWeighted:
"""

def __init__(
self, mdtol=0, use_src_mask=False, use_tgt_mask=False, tgt_location=None
self, mdtol=0, use_src_mask=False, use_tgt_mask=False, tgt_location="face"
):
"""
Area-weighted scheme for regridding between rectilinear grids.
Expand All @@ -870,18 +870,22 @@ def __init__(
use_tgt_mask : bool, default=False
If True, derive a mask from target cube which will tell :mod:`esmpy`
which points to ignore.
tgt_location : str or None, default=None
tgt_location : str or None, default="face"
Either "face" or "node". Describes the location for data on the mesh
if the target is not a :class:`~iris.cube.Cube`.
"""
if not (0 <= mdtol <= 1):
msg = "Value for mdtol must be in range 0 - 1, got {}."
raise ValueError(msg.format(mdtol))
if tgt_location is not None and tgt_location != "face":
raise ValueError(
"For area weighted regridding, target location must be 'face'."
)
self.mdtol = mdtol
self.use_src_mask = use_src_mask
self.use_tgt_mask = use_tgt_mask
self.tgt_location = tgt_location
self.tgt_location = "face"

def __repr__(self):
"""Return a representation of the class."""
Expand All @@ -893,7 +897,7 @@ def regridder(
tgt_grid,
use_src_mask=None,
use_tgt_mask=None,
tgt_location=None,
tgt_location="face",
):
"""
Create regridder to perform regridding from ``src_grid`` to ``tgt_grid``.
Expand All @@ -911,7 +915,7 @@ def regridder(
use_tgt_mask : :obj:`~numpy.typing.ArrayLike` or bool, optional
Array describing which elements :mod:`esmpy` will ignore on the tgt_grid.
If True, the mask will be derived from tgt_grid.
tgt_location : str or None, default=None
tgt_location : str or None, default="face"
Either "face" or "node". Describes the location for data on the mesh
if the target is not a :class:`~iris.cube.Cube`.
Expand All @@ -934,15 +938,17 @@ def regridder(
use_src_mask = self.use_src_mask
if use_tgt_mask is None:
use_tgt_mask = self.use_tgt_mask
if tgt_location is None:
tgt_location = self.tgt_location
if tgt_location is not None and tgt_location != "face":
raise ValueError(
"For area weighted regridding, target location must be 'face'."
)
return ESMFAreaWeightedRegridder(
src_grid,
tgt_grid,
mdtol=self.mdtol,
use_src_mask=use_src_mask,
use_tgt_mask=use_tgt_mask,
tgt_location=tgt_location,
tgt_location="face",
)


Expand Down
11 changes: 11 additions & 0 deletions esmf_regrid/tests/unit/schemes/test_ESMFAreaWeighted.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,14 @@ def test_mask_from_regridder(mask_keyword):
Checks that use_src_mask and use_tgt_mask are passed down correctly.
"""
_test_mask_from_regridder(ESMFAreaWeighted, mask_keyword)


def test_invalid_tgt_location():
"""
Test initialisation of :class:`esmf_regrid.schemes.ESMFAreaWeighted`.
Checks that initialisation fails when tgt_location is not "face".
"""
match = "For area weighted regridding, target location must be 'face'."
with pytest.raises(ValueError, match=match):
_ = ESMFAreaWeighted(tgt_location="node")
18 changes: 18 additions & 0 deletions esmf_regrid/tests/unit/schemes/test_ESMFAreaWeightedRegridder.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,24 @@ def test_invalid_mdtol():
_ = ESMFAreaWeightedRegridder(src, tgt, mdtol=-1)


def test_invalid_tgt_location():
"""
Test initialisation of :class:`esmf_regrid.schemes.ESMFAreaWeightedRegridder`.
Checks that initialisation fails when tgt_location is not "face".
"""
n_lons = 6
n_lats = 5
lon_bounds = (-180, 180)
lat_bounds = (-90, 90)
src = _grid_cube(n_lons, n_lats, lon_bounds, lat_bounds, circular=True)
tgt = _grid_cube(n_lons, n_lats, lon_bounds, lat_bounds, circular=True)

match = "For area weighted regridding, target location must be 'face'."
with pytest.raises(ValueError, match=match):
_ = ESMFAreaWeightedRegridder(src, tgt, tgt_location="node")


def test_curvilinear_equivalence():
"""
Test initialisation of :class:`esmf_regrid.schemes.ESMFAreaWeightedRegridder`.
Expand Down

0 comments on commit e860743

Please sign in to comment.