Skip to content

Commit

Permalink
Allow regridding for projections in non-degree type units (#178)
Browse files Browse the repository at this point in the history
* allow regridding for projections in non-degree type units

* fix tests

* add tests

* fix tests

* address review comments
  • Loading branch information
stephenworsley authored Nov 1, 2023
1 parent e61fed6 commit 4f685cf
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 8 deletions.
11 changes: 7 additions & 4 deletions esmf_regrid/schemes.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,13 @@ def _cube_to_GridInfo(cube, center=False, resolution=None, mask=None):
lat_bound_array = _contiguous_masked(lat.bounds, mask)
# 2D coords must be AuxCoords, which do not have a circular attribute.
circular = False
lon_bound_array = lon.units.convert(lon_bound_array, Unit("degrees"))
lat_bound_array = lat.units.convert(lat_bound_array, Unit("degrees"))
lon_points = lon.units.convert(lon.points, Unit("degrees"))
lat_points = lon.units.convert(lat.points, Unit("degrees"))
lon_points = lon.points
lat_points = lat.points
if crs is None:
lon_bound_array = lon.units.convert(lon_bound_array, Unit("degrees"))
lat_bound_array = lat.units.convert(lat_bound_array, Unit("degrees"))
lon_points = lon.units.convert(lon_points, Unit("degrees"))
lat_points = lon.units.convert(lat_points, Unit("degrees"))
if resolution is None:
grid_info = GridInfo(
lon_points,
Expand Down
49 changes: 49 additions & 0 deletions esmf_regrid/tests/unit/schemes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""Unit tests for `esmf_regrid.schemes`."""

from iris.coord_systems import OSGB
import numpy as np
from numpy import ma
import pytest

from esmf_regrid.schemes import ESMFAreaWeighted, ESMFBilinear, ESMFNearest
from esmf_regrid.tests.unit.schemes.test__cube_to_GridInfo import _grid_cube
from esmf_regrid.tests.unit.schemes.test__mesh_to_MeshInfo import (
_gridlike_mesh,
Expand Down Expand Up @@ -166,3 +168,50 @@ def _test_mask_from_regridder(scheme, mask_keyword):
np.testing.assert_allclose(
getattr(rg_from_different, regridder_attr), mask_different
)


def _test_non_degree_crs(scheme):
"""Test regridding scheme is compatible with coordinates with non-degree units."""
coord_system = OSGB()

# This definition comes from a small section of real user data.
n_lons_src = 2
n_lats_src = 3
lon_bounds = (-197500, -192500)
lat_bounds = (1247500, 1237500)
tm_cube = _grid_cube(
n_lons_src,
n_lats_src,
lon_bounds,
lat_bounds,
circular=False,
coord_system=coord_system,
standard_names=["projection_x_coordinate", "projection_y_coordinate"],
units="m",
)
data = np.arange(n_lats_src * n_lons_src).reshape([n_lats_src, n_lons_src])
tm_cube.data = data

n_lons_tgt = 12
n_lats_tgt = 14
lon_bounds_tgt = (-13, -12.8)
lat_bounds_tgt = (60.5, 60.7)
cube_tgt = _grid_cube(
n_lons_tgt, n_lats_tgt, lon_bounds_tgt, lat_bounds_tgt, circular=True
)

result = tm_cube.regrid(cube_tgt, scheme())

# Set expected results, this varies depending on the scheme.
if scheme is ESMFAreaWeighted:
expected_sum, expected_unmasked = 50.86147272655136, 21
elif scheme is ESMFBilinear:
expected_sum, expected_unmasked = 35.90837983047451, 13
elif scheme is ESMFNearest:
expected_sum, expected_unmasked = 490, 168

# Check that the data is as expected.
assert np.isclose(result.data.sum(), expected_sum)

# Check that the number of masked points is as expected.
assert (1 - result.data.mask).sum() == expected_unmasked
6 changes: 6 additions & 0 deletions esmf_regrid/tests/unit/schemes/test_ESMFAreaWeighted.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
_test_invalid_mdtol,
_test_mask_from_init,
_test_mask_from_regridder,
_test_non_degree_crs,
)


Expand Down Expand Up @@ -62,3 +63,8 @@ def test_invalid_tgt_location():
match = "For area weighted regridding, target location must be 'face'."
with pytest.raises(ValueError, match=match):
_ = ESMFAreaWeighted(tgt_location="node")


def test_non_degree_crs():
"""Test for coordinates with non-degree units."""
_test_non_degree_crs(ESMFAreaWeighted)
6 changes: 6 additions & 0 deletions esmf_regrid/tests/unit/schemes/test_ESMFBilinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
_test_invalid_mdtol,
_test_mask_from_init,
_test_mask_from_regridder,
_test_non_degree_crs,
)


Expand Down Expand Up @@ -51,3 +52,8 @@ def test_mask_from_regridder(mask_keyword):
Checks that use_src_mask and use_tgt_mask are passed down correctly.
"""
_test_mask_from_regridder(ESMFBilinear, mask_keyword)


def test_non_degree_crs():
"""Test for coordinates with non-degree units."""
_test_non_degree_crs(ESMFBilinear)
6 changes: 6 additions & 0 deletions esmf_regrid/tests/unit/schemes/test_ESMFNearest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from esmf_regrid.tests.unit.schemes.__init__ import (
_test_mask_from_init,
_test_mask_from_regridder,
_test_non_degree_crs,
)
from esmf_regrid.tests.unit.schemes.test__cube_to_GridInfo import _grid_cube
from esmf_regrid.tests.unit.schemes.test__mesh_to_MeshInfo import (
Expand Down Expand Up @@ -94,3 +95,8 @@ def test_mask_from_regridder(mask_keyword):
Checks that use_src_mask and use_tgt_mask are passed down correctly.
"""
_test_mask_from_regridder(ESMFNearest, mask_keyword)


def test_non_degree_crs():
"""Test for coordinates with non-degree units."""
_test_non_degree_crs(ESMFNearest)
10 changes: 6 additions & 4 deletions esmf_regrid/tests/unit/schemes/test__cube_to_GridInfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,21 +67,23 @@ def _grid_cube(
lat_outer_bounds,
circular=False,
coord_system=None,
standard_names=["longitude", "latitude"],
units="degrees",
):
lon_points, lon_bounds = _generate_points_and_bounds(n_lons, lon_outer_bounds)
lon = DimCoord(
lon_points,
"longitude",
units="degrees",
standard_names[0],
units=units,
bounds=lon_bounds,
circular=circular,
coord_system=coord_system,
)
lat_points, lat_bounds = _generate_points_and_bounds(n_lats, lat_outer_bounds)
lat = DimCoord(
lat_points,
"latitude",
units="degrees",
standard_names[1],
units=units,
bounds=lat_bounds,
coord_system=coord_system,
)
Expand Down

0 comments on commit 4f685cf

Please sign in to comment.