Skip to content

Commit

Permalink
Fix unmasked connectivity bug (SciTools#385)
Browse files Browse the repository at this point in the history
* fix unmasked connectivity bug

* adapt to iris mesh API change

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* adapt to iris mesh API change

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* adapt to iris mesh API change

* flake8

* flake8

* normalise esmf field data

* Update esmf_regrid/experimental/unstructured_regrid.py

Co-authored-by: Patrick Peglar <[email protected]>

* fix tests

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Patrick Peglar <[email protected]>
  • Loading branch information
3 people authored Jul 23, 2024
1 parent 1772591 commit 5354829
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 28 deletions.
11 changes: 8 additions & 3 deletions esmf_regrid/experimental/unstructured_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,14 @@ def _as_esmf_info(self):
nodeCoord = self.node_coords.flatten()
nodeOwner = np.zeros([num_node]) # regridding currently serial
elemId = np.arange(1, num_elem + 1)
elemType = self.fnc.count(axis=1)
# Experiments seem to indicate that ESMF is using 0 indexing here
elemConn = self.fnc.compressed() - self.nsi
if np.ma.isMaskedArray(self.fnc):
elemType = self.fnc.count(axis=1)
# Experiments seem to indicate that ESMF is using 0 indexing here
elemConn = self.fnc.compressed() - self.nsi
else:
elemType = np.full(self.fnc.shape[:1], self.fnc.shape[1])
# Experiments seem to indicate that ESMF is using 0 indexing here
elemConn = self.fnc.flatten() - self.nsi
elemCoord = self.elem_coords
result = (
num_node,
Expand Down
15 changes: 11 additions & 4 deletions esmf_regrid/experimental/unstructured_scheme.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
"""Provides an iris interface for unstructured regridding."""

from iris.experimental.ugrid import Mesh
try:
from iris.experimental.ugrid import MeshXY
except ImportError as exc:
# Prior to v3.10.0, `MeshXY` could was named `Mesh`.
try:
from iris.experimental.ugrid import Mesh as MeshXY
except ImportError:
raise exc

from esmf_regrid import check_method, Constants
from esmf_regrid.schemes import (
Expand Down Expand Up @@ -288,9 +295,9 @@ def __init__(
----------
src : :class:`iris.cube.Cube`
The rectilinear :class:`~iris.cube.Cube` cube providing the source grid.
tgt : :class:`iris.cube.Cube` or :class:`iris.experimental.ugrid.Mesh`
tgt : :class:`iris.cube.Cube` or :class:`iris.experimental.ugrid.MeshXY`
The unstructured :class:`~iris.cube.Cube`or
:class:`~iris.experimental.ugrid.Mesh` providing the target mesh.
:class:`~iris.experimental.ugrid.MeshXY` providing the target mesh.
mdtol : float, optional
Tolerance of missing data. The value returned in each element of
the returned array will be masked if the fraction of masked data
Expand Down Expand Up @@ -330,7 +337,7 @@ def __init__(
or ``tgt`` respectively are not constant over non-horizontal dimensions.
"""
if not isinstance(tgt, Mesh) and tgt.mesh is None:
if not isinstance(tgt, MeshXY) and tgt.mesh is None:
raise ValueError("tgt has no mesh.")
super().__init__(
src,
Expand Down
44 changes: 26 additions & 18 deletions esmf_regrid/schemes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,17 @@
import iris.coords
import iris.cube
from iris.exceptions import CoordinateNotFoundError
from iris.experimental.ugrid import Mesh
import numpy as np

try:
from iris.experimental.ugrid import MeshXY
except ImportError as exc:
# Prior to v3.10.0, `MeshXY` could was named `Mesh`.
try:
from iris.experimental.ugrid import Mesh as MeshXY
except ImportError:
raise exc

from esmf_regrid import check_method, Constants
from esmf_regrid.esmf_regridder import GridInfo, RefinedGridInfo, Regridder
from esmf_regrid.experimental.unstructured_regrid import MeshInfo
Expand Down Expand Up @@ -38,7 +46,7 @@ def _get_mask(cube_or_mesh, use_mask=True):
if use_mask is False:
result = None
elif use_mask is True:
if isinstance(cube_or_mesh, Mesh):
if isinstance(cube_or_mesh, MeshXY):
result = None
else:
cube = cube_or_mesh
Expand Down Expand Up @@ -480,7 +488,7 @@ def _make_gridinfo(cube, method, resolution, mask):

def _make_meshinfo(cube_or_mesh, method, mask, src_or_tgt, location=None):
method = check_method(method)
if isinstance(cube_or_mesh, Mesh):
if isinstance(cube_or_mesh, MeshXY):
mesh = cube_or_mesh
else:
mesh = cube_or_mesh.mesh
Expand Down Expand Up @@ -699,7 +707,7 @@ def _regrid_rectilinear_to_unstructured__prepare(
"""
grid_x = _get_coord(src_grid_cube, "x")
grid_y = _get_coord(src_grid_cube, "y")
if isinstance(tgt_cube_or_mesh, Mesh):
if isinstance(tgt_cube_or_mesh, MeshXY):
mesh = tgt_cube_or_mesh
location = tgt_location
else:
Expand Down Expand Up @@ -795,7 +803,7 @@ def _regrid_unstructured_to_unstructured__prepare(
The 'regrid info' returned can be re-used over many 2d slices.
"""
if isinstance(tgt_cube_or_mesh, Mesh):
if isinstance(tgt_cube_or_mesh, MeshXY):
mesh = tgt_cube_or_mesh
location = tgt_location
else:
Expand Down Expand Up @@ -997,9 +1005,9 @@ def regridder(
----------
src_grid : :class:`iris.cube.Cube`
The :class:`~iris.cube.Cube` defining the source.
tgt_grid : :class:`iris.cube.Cube` or :class:`iris.experimental.ugrid.Mesh`
tgt_grid : :class:`iris.cube.Cube` or :class:`iris.experimental.ugrid.MeshXY`
The unstructured :class:`~iris.cube.Cube`or
:class:`~iris.experimental.ugrid.Mesh` defining the target.
:class:`~iris.experimental.ugrid.MeshXY` defining the target.
src_resolution, tgt_resolution : int, optional
If present, represents the amount of latitude slices per source/target cell
given to ESMF for calculation. If resolution is set, ``src`` and ``tgt``
Expand Down Expand Up @@ -1109,9 +1117,9 @@ def regridder(
----------
src_grid : :class:`iris.cube.Cube`
The :class:`~iris.cube.Cube` defining the source.
tgt_grid : :class:`iris.cube.Cube` or :class:`iris.experimental.ugrid.Mesh`
tgt_grid : :class:`iris.cube.Cube` or :class:`iris.experimental.ugrid.MeshXY`
The unstructured :class:`~iris.cube.Cube`or
:class:`~iris.experimental.ugrid.Mesh` defining the target.
:class:`~iris.experimental.ugrid.MeshXY` defining the target.
use_src_mask : :obj:`~numpy.typing.ArrayLike` or bool, optional
Array describing which elements :mod:`esmpy` will ignore on the src_grid.
If True, the mask will be derived from src_grid.
Expand Down Expand Up @@ -1217,9 +1225,9 @@ def regridder(
----------
src_grid : :class:`iris.cube.Cube`
The :class:`~iris.cube.Cube` defining the source.
tgt_grid : :class:`iris.cube.Cube` or :class:`iris.experimental.ugrid.Mesh`
tgt_grid : :class:`iris.cube.Cube` or :class:`iris.experimental.ugrid.MeshXY`
The unstructured :class:`~iris.cube.Cube`or
:class:`~iris.experimental.ugrid.Mesh` defining the target.
:class:`~iris.experimental.ugrid.MeshXY` defining the target.
use_src_mask : :obj:`~numpy.typing.ArrayLike` or bool, optional
Array describing which elements :mod:`esmpy` will ignore on the src_grid.
If True, the mask will be derived from src_grid.
Expand Down Expand Up @@ -1280,7 +1288,7 @@ def __init__(
----------
src : :class:`iris.cube.Cube`
The rectilinear :class:`~iris.cube.Cube` providing the source grid.
tgt : :class:`iris.cube.Cube` or :class:`iris.experimental.ugrid.Mesh`
tgt : :class:`iris.cube.Cube` or :class:`iris.experimental.ugrid.MeshXY`
The rectilinear :class:`~iris.cube.Cube` providing the target grid.
method : :class:`Constants.Method`
The method to be used to calculate weights.
Expand Down Expand Up @@ -1325,7 +1333,7 @@ def __init__(
kwargs["tgt_mask"] = self.tgt_mask

src_is_mesh = src.mesh is not None
tgt_is_mesh = isinstance(tgt, Mesh) or tgt.mesh is not None
tgt_is_mesh = isinstance(tgt, MeshXY) or tgt.mesh is not None
if src_is_mesh:
if tgt_is_mesh:
prepare_func = _regrid_unstructured_to_unstructured__prepare
Expand Down Expand Up @@ -1457,9 +1465,9 @@ def __init__(
----------
src : :class:`iris.cube.Cube`
The rectilinear :class:`~iris.cube.Cube` providing the source.
tgt : :class:`iris.cube.Cube` or :class:`iris.experimental.ugrid.Mesh`
tgt : :class:`iris.cube.Cube` or :class:`iris.experimental.ugrid.MeshXY`
The unstructured :class:`~iris.cube.Cube`or
:class:`~iris.experimental.ugrid.Mesh` defining the target.
:class:`~iris.experimental.ugrid.MeshXY` defining the target.
mdtol : float, default=0
Tolerance of missing data. The value returned in each element of
the returned array will be masked if the fraction of masked data
Expand Down Expand Up @@ -1537,7 +1545,7 @@ def __init__(
The rectilinear :class:`~iris.cube.Cube` providing the source.
tgt : :class:`iris.cube.Cube` or :class:`iris.experimental.ugrid.Mesh`
The unstructured :class:`~iris.cube.Cube`or
:class:`~iris.experimental.ugrid.Mesh` defining the target.
:class:`~iris.experimental.ugrid.MeshXY` defining the target.
mdtol : float, default=0
Tolerance of missing data. The value returned in each element of
the returned array will be masked if the fraction of masked data
Expand Down Expand Up @@ -1594,9 +1602,9 @@ def __init__(
----------
src : :class:`iris.cube.Cube`
The rectilinear :class:`~iris.cube.Cube` providing the source.
tgt : :class:`iris.cube.Cube` or :class:`iris.experimental.ugrid.Mesh`
tgt : :class:`iris.cube.Cube` or :class:`iris.experimental.ugrid.MeshXY`
The unstructured :class:`~iris.cube.Cube`or
:class:`~iris.experimental.ugrid.Mesh` defining the target.
:class:`~iris.experimental.ugrid.MeshXY` defining the target.
precomputed_weights : :class:`scipy.sparse.spmatrix`, optional
If ``None``, :mod:`esmpy` will be used to
calculate regridding weights. Otherwise, :mod:`esmpy` will be bypassed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,22 @@ def test_make_mesh():
assert esmf_mesh_0.__repr__() == esmf_mesh_1.__repr__() == expected_repr


def test_connectivity_mask_equivalence():
"""Test for handling connectivity masks :meth:`~esmf_regrid.esmf_regridder.GridInfo.make_esmf_field`."""
coords, nodes, _ = _make_small_mesh_args()
coords = coords[:-1]
nodes = nodes[:, :-1]
unmasked_nodes = nodes.filled()
mesh = MeshInfo(coords, unmasked_nodes, 0)
esmf_mesh_unmasked = mesh.make_esmf_field()
esmf_mesh_unmasked.data[:] = 0

mesh = MeshInfo(coords, nodes, 0)
esmf_mesh_masked = mesh.make_esmf_field()
esmf_mesh_masked.data[:] = 0
assert esmf_mesh_unmasked.__repr__() == esmf_mesh_masked.__repr__()


def test_regrid_with_mesh():
"""Basic test for regridding with :meth:`~esmf_regrid.esmf_regridder.GridInfo.make_esmf_field`."""
mesh_args = _make_small_mesh_args()
Expand Down
15 changes: 12 additions & 3 deletions esmf_regrid/tests/unit/schemes/test__mesh_to_MeshInfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,20 @@

from iris.coords import AuxCoord
from iris.cube import Cube
from iris.experimental.ugrid import Connectivity, Mesh
from iris.experimental.ugrid import Connectivity
import numpy as np
from numpy import ma
import scipy.sparse

try:
from iris.experimental.ugrid import MeshXY
except ImportError as exc:
# Prior to v3.10.0, `MeshXY` could was named `Mesh`.
try:
from iris.experimental.ugrid import Mesh as MeshXY
except ImportError:
raise exc

from esmf_regrid.esmf_regridder import Regridder
from esmf_regrid.schemes import _mesh_to_MeshInfo

Expand Down Expand Up @@ -69,7 +78,7 @@ def _example_mesh():
lat_values = [60, -60, -60, 60, 10, 0]
lons = AuxCoord(lon_values, standard_name="longitude")
lats = AuxCoord(lat_values, standard_name="latitude")
mesh = Mesh(2, ((lons, "x"), (lats, "y")), fnc)
mesh = MeshXY(2, ((lons, "x"), (lats, "y")), fnc)
return mesh


Expand Down Expand Up @@ -170,7 +179,7 @@ def _gridlike_mesh(n_lons, n_lats, nsi=0):
)
lons = AuxCoord(node_lons, standard_name="longitude")
lats = AuxCoord(node_lats, standard_name="latitude")
mesh = Mesh(2, ((lons, "x"), (lats, "y")), [fnc, enc])
mesh = MeshXY(2, ((lons, "x"), (lats, "y")), [fnc, enc])

# In order to add a mesh to a cube, face locations must be added.
face_lon_coord = AuxCoord(face_lons, standard_name="longitude")
Expand Down

0 comments on commit 5354829

Please sign in to comment.