From e86074339196b51a12b76e3fa33c19de496f66db Mon Sep 17 00:00:00 2001 From: "stephen.worsley" Date: Mon, 21 Aug 2023 15:26:19 +0100 Subject: [PATCH] clarify and test tgt_location defaults --- .../experimental/unstructured_scheme.py | 2 +- esmf_regrid/schemes.py | 22 ++++++++++++------- .../unit/schemes/test_ESMFAreaWeighted.py | 11 ++++++++++ .../schemes/test_ESMFAreaWeightedRegridder.py | 18 +++++++++++++++ 4 files changed, 44 insertions(+), 9 deletions(-) diff --git a/esmf_regrid/experimental/unstructured_scheme.py b/esmf_regrid/experimental/unstructured_scheme.py index 48996200..f17aadcd 100644 --- a/esmf_regrid/experimental/unstructured_scheme.py +++ b/esmf_regrid/experimental/unstructured_scheme.py @@ -326,7 +326,7 @@ def __init__( a boolean value. If True, this array is taken from the mask on the data in ``tgt``. If False, no mask will be taken and all points will be used in weights calculation. - tgt_location : str or None, default="face" + tgt_location : str or None, default=None Either "face" or "node". Describes the location for data on the mesh if the target is not a :class:`~iris.cube.Cube`. diff --git a/esmf_regrid/schemes.py b/esmf_regrid/schemes.py index 780059dd..a9c2528e 100644 --- a/esmf_regrid/schemes.py +++ b/esmf_regrid/schemes.py @@ -849,7 +849,7 @@ class ESMFAreaWeighted: """ def __init__( - self, mdtol=0, use_src_mask=False, use_tgt_mask=False, tgt_location=None + self, mdtol=0, use_src_mask=False, use_tgt_mask=False, tgt_location="face" ): """ Area-weighted scheme for regridding between rectilinear grids. @@ -870,7 +870,7 @@ def __init__( use_tgt_mask : bool, default=False If True, derive a mask from target cube which will tell :mod:`esmpy` which points to ignore. - tgt_location : str or None, default=None + tgt_location : str or None, default="face" Either "face" or "node". Describes the location for data on the mesh if the target is not a :class:`~iris.cube.Cube`. @@ -878,10 +878,14 @@ def __init__( if not (0 <= mdtol <= 1): msg = "Value for mdtol must be in range 0 - 1, got {}." raise ValueError(msg.format(mdtol)) + if tgt_location is not None and tgt_location != "face": + raise ValueError( + "For area weighted regridding, target location must be 'face'." + ) self.mdtol = mdtol self.use_src_mask = use_src_mask self.use_tgt_mask = use_tgt_mask - self.tgt_location = tgt_location + self.tgt_location = "face" def __repr__(self): """Return a representation of the class.""" @@ -893,7 +897,7 @@ def regridder( tgt_grid, use_src_mask=None, use_tgt_mask=None, - tgt_location=None, + tgt_location="face", ): """ Create regridder to perform regridding from ``src_grid`` to ``tgt_grid``. @@ -911,7 +915,7 @@ def regridder( use_tgt_mask : :obj:`~numpy.typing.ArrayLike` or bool, optional Array describing which elements :mod:`esmpy` will ignore on the tgt_grid. If True, the mask will be derived from tgt_grid. - tgt_location : str or None, default=None + tgt_location : str or None, default="face" Either "face" or "node". Describes the location for data on the mesh if the target is not a :class:`~iris.cube.Cube`. @@ -934,15 +938,17 @@ def regridder( use_src_mask = self.use_src_mask if use_tgt_mask is None: use_tgt_mask = self.use_tgt_mask - if tgt_location is None: - tgt_location = self.tgt_location + if tgt_location is not None and tgt_location != "face": + raise ValueError( + "For area weighted regridding, target location must be 'face'." + ) return ESMFAreaWeightedRegridder( src_grid, tgt_grid, mdtol=self.mdtol, use_src_mask=use_src_mask, use_tgt_mask=use_tgt_mask, - tgt_location=tgt_location, + tgt_location="face", ) diff --git a/esmf_regrid/tests/unit/schemes/test_ESMFAreaWeighted.py b/esmf_regrid/tests/unit/schemes/test_ESMFAreaWeighted.py index 55127f36..49cd91c0 100644 --- a/esmf_regrid/tests/unit/schemes/test_ESMFAreaWeighted.py +++ b/esmf_regrid/tests/unit/schemes/test_ESMFAreaWeighted.py @@ -51,3 +51,14 @@ def test_mask_from_regridder(mask_keyword): Checks that use_src_mask and use_tgt_mask are passed down correctly. """ _test_mask_from_regridder(ESMFAreaWeighted, mask_keyword) + + +def test_invalid_tgt_location(): + """ + Test initialisation of :class:`esmf_regrid.schemes.ESMFAreaWeighted`. + + Checks that initialisation fails when tgt_location is not "face". + """ + match = "For area weighted regridding, target location must be 'face'." + with pytest.raises(ValueError, match=match): + _ = ESMFAreaWeighted(tgt_location="node") diff --git a/esmf_regrid/tests/unit/schemes/test_ESMFAreaWeightedRegridder.py b/esmf_regrid/tests/unit/schemes/test_ESMFAreaWeightedRegridder.py index 86bcbf8d..c5d77ae9 100644 --- a/esmf_regrid/tests/unit/schemes/test_ESMFAreaWeightedRegridder.py +++ b/esmf_regrid/tests/unit/schemes/test_ESMFAreaWeightedRegridder.py @@ -88,6 +88,24 @@ def test_invalid_mdtol(): _ = ESMFAreaWeightedRegridder(src, tgt, mdtol=-1) +def test_invalid_tgt_location(): + """ + Test initialisation of :class:`esmf_regrid.schemes.ESMFAreaWeightedRegridder`. + + Checks that initialisation fails when tgt_location is not "face". + """ + n_lons = 6 + n_lats = 5 + lon_bounds = (-180, 180) + lat_bounds = (-90, 90) + src = _grid_cube(n_lons, n_lats, lon_bounds, lat_bounds, circular=True) + tgt = _grid_cube(n_lons, n_lats, lon_bounds, lat_bounds, circular=True) + + match = "For area weighted regridding, target location must be 'face'." + with pytest.raises(ValueError, match=match): + _ = ESMFAreaWeightedRegridder(src, tgt, tgt_location="node") + + def test_curvilinear_equivalence(): """ Test initialisation of :class:`esmf_regrid.schemes.ESMFAreaWeightedRegridder`.