From e5ee8a3f49af96bd7e629356f6e2f0d131c36b7d Mon Sep 17 00:00:00 2001 From: "stephen.worsley" Date: Wed, 1 May 2024 16:24:30 +0100 Subject: [PATCH 1/9] extend regridder saving/loading --- esmf_regrid/experimental/io.py | 160 +++++++++++++++--- esmf_regrid/schemes.py | 2 + .../experimental/io/test_round_tripping.py | 79 +++++++-- 3 files changed, 200 insertions(+), 41 deletions(-) diff --git a/esmf_regrid/experimental/io.py b/esmf_regrid/experimental/io.py index 624820c3..efa4117a 100644 --- a/esmf_regrid/experimental/io.py +++ b/esmf_regrid/experimental/io.py @@ -1,5 +1,7 @@ """Provides load/save functions for regridders.""" +from contextlib import contextmanager + import iris from iris.coords import AuxCoord from iris.cube import Cube, CubeList @@ -13,9 +15,19 @@ GridToMeshESMFRegridder, MeshToGridESMFRegridder, ) +from esmf_regrid.schemes import ( + ESMFAreaWeightedRegridder, + ESMFBilinearRegridder, + ESMFNearestRegridder, + MeshRecord, + GridRecord, +) SUPPORTED_REGRIDDERS = [ + ESMFAreaWeightedRegridder, + ESMFBilinearRegridder, + ESMFNearestRegridder, GridToMeshESMFRegridder, MeshToGridESMFRegridder, ] @@ -34,6 +46,8 @@ MDTOL = "mdtol" METHOD = "method" RESOLUTION = "resolution" +SOURCE_RESOLUTION = "src_resolution" +TARGET_RESOLUTION = "tgt_resolution" def _add_mask_to_cube(mask, cube, name): @@ -43,6 +57,49 @@ def _add_mask_to_cube(mask, cube, name): cube.add_aux_coord(mask_coord, list(range(cube.ndim))) +@contextmanager +def managed_var_name(src_cube, tgt_cube): + src_coord_names = [] + src_mesh_coords = [] + if src_cube.mesh is not None: + src_mesh = src_cube.mesh + src_mesh_coords = src_mesh.coords() + for coord in src_mesh_coords: + src_coord_names.append(coord.var_name) + tgt_coord_names = [] + tgt_mesh_coords = [] + if tgt_cube.mesh is not None: + tgt_mesh = tgt_cube.mesh + tgt_mesh_coords = tgt_mesh.coords() + for coord in tgt_mesh_coords: + tgt_coord_names.append(coord.var_name) + + try: + for coord in src_mesh_coords: + coord.var_name = "_".join([SOURCE_NAME, "mesh", coord.name()]) + for coord in tgt_mesh_coords: + coord.var_name = "_".join([TARGET_NAME, "mesh", coord.name()]) + yield None + finally: + for coord, var_name in zip(src_mesh_coords, src_coord_names): + coord.var_name = var_name + for coord, var_name in zip(tgt_mesh_coords, tgt_coord_names): + coord.var_name = var_name + + +def _clean_var_names(cube): + cube.var_name = None + for coord in cube.coords(): + coord.var_name = None + if cube.mesh is not None: + cube.mesh.var_name = None + for coord in cube.mesh.coords(): + coord.var_name = None + for con in cube.mesh.connectivities(): + con.var_name = None + return cube + + def save_regridder(rg, filename): """ Save a regridder scheme instance. @@ -76,28 +133,52 @@ def _standard_grid_cube(grid, name): cube.add_aux_coord(grid[1], [0, 1]) return cube - if regridder_type == "GridToMeshESMFRegridder": + def _standard_mesh_cube(mesh, location, name): + mesh_coords = mesh.to_MeshCoords(location) + data = np.zeros(mesh_coords[0].points.shape[0]) + cube = Cube(data, var_name=name, long_name=name) + for coord in mesh_coords: + cube.add_aux_coord(coord, 0) + return cube + + if regridder_type in [ + "ESMFAreaWeightedRegridder", + "ESMFBilinearRegridder", + "ESMFNearestRegridder", + ]: + src_grid = rg._src + if isinstance(src_grid, GridRecord): + src_cube = _standard_grid_cube(src_grid, SOURCE_NAME) + elif isinstance(src_grid, MeshRecord): + src_mesh, src_location = src_grid + src_cube = _standard_mesh_cube(src_mesh, src_location, SOURCE_NAME) + else: + raise ValueError("Improper type for `rg._src`.") + _add_mask_to_cube(rg.src_mask, src_cube, SOURCE_MASK_NAME) + + tgt_grid = rg._tgt + if isinstance(tgt_grid, GridRecord): + tgt_cube = _standard_grid_cube(tgt_grid, TARGET_NAME) + elif isinstance(tgt_grid, MeshRecord): + tgt_mesh, tgt_location = tgt_grid + tgt_cube = _standard_mesh_cube(tgt_mesh, tgt_location, TARGET_NAME) + else: + raise ValueError("Improper type for `rg._tgt`.") + _add_mask_to_cube(rg.tgt_mask, tgt_cube, TARGET_MASK_NAME) + elif regridder_type == "GridToMeshESMFRegridder": src_grid = (rg.grid_y, rg.grid_x) src_cube = _standard_grid_cube(src_grid, SOURCE_NAME) _add_mask_to_cube(rg.src_mask, src_cube, SOURCE_MASK_NAME) tgt_mesh = rg.mesh tgt_location = rg.location - tgt_mesh_coords = tgt_mesh.to_MeshCoords(tgt_location) - tgt_data = np.zeros(tgt_mesh_coords[0].points.shape[0]) - tgt_cube = Cube(tgt_data, var_name=TARGET_NAME, long_name=TARGET_NAME) - for coord in tgt_mesh_coords: - tgt_cube.add_aux_coord(coord, 0) + tgt_cube = _standard_mesh_cube(tgt_mesh, tgt_location, TARGET_NAME) _add_mask_to_cube(rg.tgt_mask, tgt_cube, TARGET_MASK_NAME) elif regridder_type == "MeshToGridESMFRegridder": src_mesh = rg.mesh src_location = rg.location - src_mesh_coords = src_mesh.to_MeshCoords(src_location) - src_data = np.zeros(src_mesh_coords[0].points.shape[0]) - src_cube = Cube(src_data, var_name=SOURCE_NAME, long_name=SOURCE_NAME) - for coord in src_mesh_coords: - src_cube.add_aux_coord(coord, 0) + src_cube = _standard_mesh_cube(src_mesh, src_location, SOURCE_NAME) _add_mask_to_cube(rg.src_mask, src_cube, SOURCE_MASK_NAME) tgt_grid = (rg.grid_y, rg.grid_x) @@ -112,7 +193,18 @@ def _standard_grid_cube(grid, name): method = str(check_method(rg.method).name) - resolution = rg.resolution + if regridder_type in ["GridToMeshESMFRegridder", "MeshToGridESMFRegridder"]: + resolution = rg.resolution + src_resolution = None + tgt_resolution = None + elif regridder_type == "ESMFAreaWeightedRegridder": + resolution = None + src_resolution = rg.src_resolution + tgt_resolution = rg.tgt_resolution + else: + resolution = None + src_resolution = None + tgt_resolution = None weight_matrix = rg.regridder.weight_matrix reformatted_weight_matrix = scipy.sparse.coo_matrix(weight_matrix) @@ -141,6 +233,10 @@ def _standard_grid_cube(grid, name): } if resolution is not None: attributes[RESOLUTION] = resolution + if src_resolution is not None: + attributes[SOURCE_RESOLUTION] = src_resolution + if tgt_resolution is not None: + attributes[TARGET_RESOLUTION] = tgt_resolution weights_cube = Cube(weight_data, var_name=WEIGHTS_NAME, long_name=WEIGHTS_NAME) row_coord = AuxCoord( @@ -158,17 +254,14 @@ def _standard_grid_cube(grid, name): long_name=WEIGHTS_SHAPE_NAME, ) - # Avoid saving bug by placing the mesh cube second. - # TODO: simplify this when this bug is fixed in iris. - if regridder_type == "GridToMeshESMFRegridder": + # Save cubes while ensuring var_names do not conflict for the sake of consistency. + with managed_var_name(src_cube, tgt_cube): cube_list = CubeList([src_cube, tgt_cube, weights_cube, weight_shape_cube]) - elif regridder_type == "MeshToGridESMFRegridder": - cube_list = CubeList([tgt_cube, src_cube, weights_cube, weight_shape_cube]) - for cube in cube_list: - cube.attributes = attributes + for cube in cube_list: + cube.attributes = attributes - iris.fileformats.netcdf.save(cube_list, filename) + iris.fileformats.netcdf.save(cube_list, filename) def load_regridder(filename): @@ -193,8 +286,8 @@ def load_regridder(filename): cubes = iris.load(filename) # Extract the source, target and metadata information. - src_cube = cubes.extract_cube(SOURCE_NAME) - tgt_cube = cubes.extract_cube(TARGET_NAME) + src_cube = _clean_var_names(cubes.extract_cube(SOURCE_NAME)) + tgt_cube = _clean_var_names(cubes.extract_cube(TARGET_NAME)) weights_cube = cubes.extract_cube(WEIGHTS_NAME) weight_shape_cube = cubes.extract_cube(WEIGHTS_SHAPE_NAME) @@ -210,8 +303,12 @@ def load_regridder(filename): ) resolution = weights_cube.attributes.get(RESOLUTION, None) + src_resolution = weights_cube.attributes.get(SOURCE_RESOLUTION, None) + tgt_resolution = weights_cube.attributes.get(TARGET_RESOLUTION, None) if resolution is not None: resolution = int(resolution) + src_resolution = int(src_resolution) + tgt_resolution = int(tgt_resolution) # Reconstruct the weight matrix. weight_data = weights_cube.data @@ -234,18 +331,25 @@ def load_regridder(filename): use_tgt_mask = False if scheme is GridToMeshESMFRegridder: - resolution_keyword = "src_resolution" + resolution_keyword = SOURCE_RESOLUTION + kwargs = {resolution_keyword: resolution, "method": method, "mdtol": mdtol} elif scheme is MeshToGridESMFRegridder: - resolution_keyword = "tgt_resolution" + resolution_keyword = TARGET_RESOLUTION + kwargs = {resolution_keyword: resolution, "method": method, "mdtol": mdtol} + elif scheme is ESMFAreaWeightedRegridder: + kwargs = { + SOURCE_RESOLUTION: src_resolution, + TARGET_RESOLUTION: tgt_resolution, + "mdtol": mdtol, + } + elif scheme is ESMFBilinearRegridder: + kwargs = {"mdtol": mdtol} else: - raise NotImplementedError - kwargs = {resolution_keyword: resolution} + kwargs = {} regridder = scheme( src_cube, tgt_cube, - mdtol=mdtol, - method=method, precomputed_weights=weight_matrix, use_src_mask=use_src_mask, use_tgt_mask=use_tgt_mask, diff --git a/esmf_regrid/schemes.py b/esmf_regrid/schemes.py index bd60e9f9..ff064f1d 100644 --- a/esmf_regrid/schemes.py +++ b/esmf_regrid/schemes.py @@ -1465,8 +1465,10 @@ def __init__( if tgt_location is not "face". """ kwargs = dict() + self.src_resolution = src_resolution if src_resolution is not None: kwargs["src_resolution"] = src_resolution + self.tgt_resolution = tgt_resolution if tgt_resolution is not None: kwargs["tgt_resolution"] = tgt_resolution if tgt_location is not None and tgt_location != "face": diff --git a/esmf_regrid/tests/unit/experimental/io/test_round_tripping.py b/esmf_regrid/tests/unit/experimental/io/test_round_tripping.py index c16adcef..3eff1cb8 100644 --- a/esmf_regrid/tests/unit/experimental/io/test_round_tripping.py +++ b/esmf_regrid/tests/unit/experimental/io/test_round_tripping.py @@ -4,7 +4,7 @@ from numpy import ma import pytest -from esmf_regrid import Constants +from esmf_regrid import Constants, ESMFAreaWeighted, ESMFBilinear, ESMFNearest from esmf_regrid.experimental.io import load_regridder, save_regridder from esmf_regrid.experimental.unstructured_scheme import ( GridToMeshESMFRegridder, @@ -122,6 +122,13 @@ def _make_mesh_to_grid_regridder( return rg, src +def _compare_ignoring_var_names(x, y): + old_var_name = x.var_name + x.var_name = y.var_name + assert x == y + x.var_name = old_var_name + + @pytest.mark.parametrize( "method", [ @@ -140,8 +147,8 @@ def test_GridToMeshESMFRegridder_round_trip(tmp_path, method): assert original_rg.location == loaded_rg.location assert original_rg.method == loaded_rg.method assert original_rg.mdtol == loaded_rg.mdtol - assert original_rg.grid_x == loaded_rg.grid_x - assert original_rg.grid_y == loaded_rg.grid_y + _compare_ignoring_var_names(original_rg.grid_x, loaded_rg.grid_x) + _compare_ignoring_var_names(original_rg.grid_y, loaded_rg.grid_y) # TODO: uncomment when iris mesh comparison becomes available. # assert original_rg.mesh == loaded_rg.mesh @@ -187,8 +194,8 @@ def test_GridToMeshESMFRegridder_round_trip(tmp_path, method): nc_filename = tmp_path / "non_circular_regridder.nc" save_regridder(original_nc_rg, nc_filename) loaded_nc_rg = load_regridder(str(nc_filename)) - assert original_nc_rg.grid_x == loaded_nc_rg.grid_x - assert original_nc_rg.grid_y == loaded_nc_rg.grid_y + _compare_ignoring_var_names(original_nc_rg.grid_x, loaded_nc_rg.grid_x) + _compare_ignoring_var_names(original_nc_rg.grid_y, loaded_nc_rg.grid_y) def test_GridToMeshESMFRegridder_curvilinear_round_trip(tmp_path): @@ -198,8 +205,8 @@ def test_GridToMeshESMFRegridder_curvilinear_round_trip(tmp_path): save_regridder(original_rg, filename) loaded_rg = load_regridder(str(filename)) - assert original_rg.grid_x == loaded_rg.grid_x - assert original_rg.grid_y == loaded_rg.grid_y + _compare_ignoring_var_names(original_rg.grid_x, loaded_rg.grid_x) + _compare_ignoring_var_names(original_rg.grid_y, loaded_rg.grid_y) # Demonstrate regridding still gives the same results. src_data = ma.arange(np.product(src.data.shape)).reshape(src.data.shape) @@ -255,8 +262,8 @@ def test_MeshToGridESMFRegridder_round_trip(tmp_path, method): assert original_rg.location == loaded_rg.location assert original_rg.method == loaded_rg.method assert original_rg.mdtol == loaded_rg.mdtol - assert original_rg.grid_x == loaded_rg.grid_x - assert original_rg.grid_y == loaded_rg.grid_y + _compare_ignoring_var_names(original_rg.grid_x, loaded_rg.grid_x) + _compare_ignoring_var_names(original_rg.grid_y, loaded_rg.grid_y) # TODO: uncomment when iris mesh comparison becomes available. # assert original_rg.mesh == loaded_rg.mesh @@ -301,8 +308,8 @@ def test_MeshToGridESMFRegridder_round_trip(tmp_path, method): nc_filename = tmp_path / "non_circular_regridder.nc" save_regridder(original_nc_rg, nc_filename) loaded_nc_rg = load_regridder(str(nc_filename)) - assert original_nc_rg.grid_x == loaded_nc_rg.grid_x - assert original_nc_rg.grid_y == loaded_nc_rg.grid_y + _compare_ignoring_var_names(original_nc_rg.grid_x, loaded_nc_rg.grid_x) + _compare_ignoring_var_names(original_nc_rg.grid_y, loaded_nc_rg.grid_y) def test_MeshToGridESMFRegridder_curvilinear_round_trip(tmp_path): @@ -312,8 +319,8 @@ def test_MeshToGridESMFRegridder_curvilinear_round_trip(tmp_path): save_regridder(original_rg, filename) loaded_rg = load_regridder(str(filename)) - assert original_rg.grid_x == loaded_rg.grid_x - assert original_rg.grid_y == loaded_rg.grid_y + _compare_ignoring_var_names(original_rg.grid_x, loaded_rg.grid_x) + _compare_ignoring_var_names(original_rg.grid_y, loaded_rg.grid_y) # Demonstrate regridding still gives the same results. src_data = ma.arange(np.product(src.data.shape)).reshape(src.data.shape) @@ -323,3 +330,49 @@ def test_MeshToGridESMFRegridder_curvilinear_round_trip(tmp_path): loaded_result = loaded_rg(src).data assert np.array_equal(original_result, loaded_result) assert np.array_equal(original_result.mask, loaded_result.mask) + + +@pytest.mark.parametrize( + "src_type,tgt_type", + [ + ("grid", "grid"), + ("grid", "mesh"), + ("mesh", "grid"), + ("mesh", "mesh"), + ], +) +@pytest.mark.parametrize( + "scheme", + [ESMFAreaWeighted, ESMFBilinear, ESMFNearest], + ids=["conservative", "linear", "nearest"], +) +def test_generic_regridder(tmp_path, src_type, tgt_type, scheme): + """Test save/load round tripping for regridders in `esmf_regrid.schemes`.""" + n_lons_src = 6 + n_lons_tgt = 3 + n_lats_src = 4 + n_lats_tgt = 2 + lon_bounds = (-180, 180) + lat_bounds = (-90, 90) + if src_type == "grid": + src = _grid_cube(n_lons_src, n_lats_src, lon_bounds, lat_bounds, circular=True) + else: + src = _gridlike_mesh_cube(n_lons_src, n_lats_src) + if tgt_type == "grid": + tgt = _grid_cube(n_lons_tgt, n_lats_tgt, lon_bounds, lat_bounds, circular=True) + elif tgt_type == "mesh": + tgt = _gridlike_mesh_cube(n_lons_tgt, n_lats_tgt) + + original_rg = scheme().regridder(src, tgt) + filename = tmp_path / "regridder.nc" + save_regridder(original_rg, filename) + loaded_rg = load_regridder(str(filename)) + + if src_type == "grid": + assert original_rg._src == loaded_rg._src + if tgt_type == "grid": + assert original_rg._tgt == loaded_rg._tgt + if scheme == ESMFAreaWeighted: + assert original_rg.src_resolution == loaded_rg.src_resolution + assert original_rg.tgt_resolution == loaded_rg.tgt_resolution + assert original_rg.mdtol == loaded_rg.mdtol From 2b7059d2a06e4458a54b0a56caf500f36c9806c3 Mon Sep 17 00:00:00 2001 From: "stephen.worsley" Date: Wed, 1 May 2024 16:33:19 +0100 Subject: [PATCH 2/9] fix bug --- esmf_regrid/experimental/io.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/esmf_regrid/experimental/io.py b/esmf_regrid/experimental/io.py index efa4117a..856befc4 100644 --- a/esmf_regrid/experimental/io.py +++ b/esmf_regrid/experimental/io.py @@ -307,7 +307,9 @@ def load_regridder(filename): tgt_resolution = weights_cube.attributes.get(TARGET_RESOLUTION, None) if resolution is not None: resolution = int(resolution) + if src_resolution is not None: src_resolution = int(src_resolution) + if tgt_resolution is not None: tgt_resolution = int(tgt_resolution) # Reconstruct the weight matrix. From 8c5e7b2c6b8e1bc7542d8a924ac1cd3e1520a66c Mon Sep 17 00:00:00 2001 From: "stephen.worsley" Date: Wed, 1 May 2024 16:38:57 +0100 Subject: [PATCH 3/9] flake8 --- esmf_regrid/experimental/io.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/esmf_regrid/experimental/io.py b/esmf_regrid/experimental/io.py index 856befc4..b36044f0 100644 --- a/esmf_regrid/experimental/io.py +++ b/esmf_regrid/experimental/io.py @@ -19,8 +19,8 @@ ESMFAreaWeightedRegridder, ESMFBilinearRegridder, ESMFNearestRegridder, - MeshRecord, GridRecord, + MeshRecord, ) @@ -58,7 +58,7 @@ def _add_mask_to_cube(mask, cube, name): @contextmanager -def managed_var_name(src_cube, tgt_cube): +def _managed_var_name(src_cube, tgt_cube): src_coord_names = [] src_mesh_coords = [] if src_cube.mesh is not None: @@ -255,7 +255,7 @@ def _standard_mesh_cube(mesh, location, name): ) # Save cubes while ensuring var_names do not conflict for the sake of consistency. - with managed_var_name(src_cube, tgt_cube): + with _managed_var_name(src_cube, tgt_cube): cube_list = CubeList([src_cube, tgt_cube, weights_cube, weight_shape_cube]) for cube in cube_list: From aecdeafa5378554959773be0670db64ef67bb70d Mon Sep 17 00:00:00 2001 From: "stephen.worsley" Date: Tue, 28 May 2024 13:36:10 +0100 Subject: [PATCH 4/9] add tests and documentation --- docs/src/userguide/examples.rst | 4 +- docs/src/userguide/scheme_comparison.rst | 10 ++- esmf_regrid/experimental/io.py | 8 +- esmf_regrid/schemes.py | 9 +++ .../experimental/io/test_round_tripping.py | 81 ++++++++++++++++++- .../schemes/test_ESMFAreaWeightedRegridder.py | 25 ++++++ 6 files changed, 128 insertions(+), 9 deletions(-) diff --git a/docs/src/userguide/examples.rst b/docs/src/userguide/examples.rst index b66e33b7..06564656 100644 --- a/docs/src/userguide/examples.rst +++ b/docs/src/userguide/examples.rst @@ -31,10 +31,10 @@ Saving and Loading a Regridder A regridder can be set up for reuse, this saves time performing the computationally expensive initialisation process:: - from esmf_regrid.experimental.unstructured_scheme import MeshToGridESMFRegridder + from esmf_regrid.experimental.unstructured_scheme import ESMFAreaWeighted # Initialise the regridder with a source mesh and target grid. - regridder = MeshToGridESMFRegridder(source_mesh_cube, target_grid_cube) + regridder = ESMFAreaWeighted().regridder(source_mesh_cube, target_grid_cube) # use the initialised regridder to regrid the data from the source cube # onto a cube with the same grid as `target_grid_cube`. diff --git a/docs/src/userguide/scheme_comparison.rst b/docs/src/userguide/scheme_comparison.rst index e29a78d6..a9bb4a1f 100644 --- a/docs/src/userguide/scheme_comparison.rst +++ b/docs/src/userguide/scheme_comparison.rst @@ -61,10 +61,12 @@ These were formerly the only way to do regridding with a source or target cube defined on an unstructured mesh. These are less flexible and require that the source/target be defined on a grid/mesh. Unlike the above regridders whose method is fixed, these regridders take a ``method`` keyword -of ``conservative``, ``bilinear`` or ``nearest``. While most of the -functionality in these regridders have been ported into the above schemes and -regridders, these remain the only regridders capable of being saved and loaded by -:mod:`esmf_regrid.experimental.io`. +of ``conservative``, ``bilinear`` or ``nearest``. All the +functionality in these regridders has now been ported into the above schemes and +regridders. Before version 0.10, these were the only regridders capable of being +saved and loaded by :mod:`esmf_regrid.experimental.io`, so while the above generic +regridders are recomended, these regridders are still available for the sake of +consistency with regridders saved from older versions. Overview: Miscellaneous Functions diff --git a/esmf_regrid/experimental/io.py b/esmf_regrid/experimental/io.py index b36044f0..c69dc9cf 100644 --- a/esmf_regrid/experimental/io.py +++ b/esmf_regrid/experimental/io.py @@ -148,7 +148,9 @@ def _standard_mesh_cube(mesh, location, name): ]: src_grid = rg._src if isinstance(src_grid, GridRecord): - src_cube = _standard_grid_cube(src_grid, SOURCE_NAME) + src_cube = _standard_grid_cube( + (src_grid.grid_y, src_grid.grid_x), SOURCE_NAME + ) elif isinstance(src_grid, MeshRecord): src_mesh, src_location = src_grid src_cube = _standard_mesh_cube(src_mesh, src_location, SOURCE_NAME) @@ -158,7 +160,9 @@ def _standard_mesh_cube(mesh, location, name): tgt_grid = rg._tgt if isinstance(tgt_grid, GridRecord): - tgt_cube = _standard_grid_cube(tgt_grid, TARGET_NAME) + tgt_cube = _standard_grid_cube( + (tgt_grid.grid_y, tgt_grid.grid_x), TARGET_NAME + ) elif isinstance(tgt_grid, MeshRecord): tgt_mesh, tgt_location = tgt_grid tgt_cube = _standard_mesh_cube(tgt_mesh, tgt_location, TARGET_NAME) diff --git a/esmf_regrid/schemes.py b/esmf_regrid/schemes.py index ff064f1d..0f348ba7 100644 --- a/esmf_regrid/schemes.py +++ b/esmf_regrid/schemes.py @@ -966,6 +966,8 @@ def regridder( self, src_grid, tgt_grid, + src_resolution=None, + tgt_resolution=None, use_src_mask=None, use_tgt_mask=None, tgt_location="face", @@ -980,6 +982,11 @@ def regridder( tgt_grid : :class:`iris.cube.Cube` or :class:`iris.experimental.ugrid.Mesh` The unstructured :class:`~iris.cube.Cube`or :class:`~iris.experimental.ugrid.Mesh` defining the target. + src_resolution, tgt_resolution : int, optional + If present, represents the amount of latitude slices per source/target cell + given to ESMF for calculation. If resolution is set, ``src`` and ``tgt`` + respectively must have strictly increasing bounds (bounds may be transposed + plus or minus 360 degrees to make the bounds strictly increasing). use_src_mask : :obj:`~numpy.typing.ArrayLike` or bool, optional Array describing which elements :mod:`esmpy` will ignore on the src_grid. If True, the mask will be derived from src_grid. @@ -1017,6 +1024,8 @@ def regridder( src_grid, tgt_grid, mdtol=self.mdtol, + src_resolution=src_resolution, + tgt_resolution=tgt_resolution, use_src_mask=use_src_mask, use_tgt_mask=use_tgt_mask, tgt_location="face", diff --git a/esmf_regrid/tests/unit/experimental/io/test_round_tripping.py b/esmf_regrid/tests/unit/experimental/io/test_round_tripping.py index 3eff1cb8..cea1e95f 100644 --- a/esmf_regrid/tests/unit/experimental/io/test_round_tripping.py +++ b/esmf_regrid/tests/unit/experimental/io/test_round_tripping.py @@ -356,7 +356,7 @@ def test_generic_regridder(tmp_path, src_type, tgt_type, scheme): lat_bounds = (-90, 90) if src_type == "grid": src = _grid_cube(n_lons_src, n_lats_src, lon_bounds, lat_bounds, circular=True) - else: + elif src_type == "mesh": src = _gridlike_mesh_cube(n_lons_src, n_lats_src) if tgt_type == "grid": tgt = _grid_cube(n_lons_tgt, n_lats_tgt, lon_bounds, lat_bounds, circular=True) @@ -376,3 +376,82 @@ def test_generic_regridder(tmp_path, src_type, tgt_type, scheme): assert original_rg.src_resolution == loaded_rg.src_resolution assert original_rg.tgt_resolution == loaded_rg.tgt_resolution assert original_rg.mdtol == loaded_rg.mdtol + + +@pytest.mark.parametrize( + "src_type,tgt_type", + [ + ("grid", "grid"), + ("grid", "mesh"), + ("mesh", "grid"), + ("mesh", "mesh"), + ], +) +@pytest.mark.parametrize( + "scheme", + [ESMFAreaWeighted, ESMFBilinear, ESMFNearest], + ids=["conservative", "linear", "nearest"], +) +def test_generic_regridder_masked(tmp_path, src_type, tgt_type, scheme): + """Test save/load round tripping for regridders in `esmf_regrid.schemes`.""" + n_lons_src = 6 + n_lons_tgt = 3 + n_lats_src = 4 + n_lats_tgt = 2 + lon_bounds = (-180, 180) + lat_bounds = (-90, 90) + if src_type == "grid": + src = _grid_cube(n_lons_src, n_lats_src, lon_bounds, lat_bounds, circular=True) + src.data = ma.array(src.data) + src.data[0, 0] = ma.masked + elif src_type == "mesh": + src = _gridlike_mesh_cube(n_lons_src, n_lats_src) + src.data = ma.array(src.data) + src.data[0] = ma.masked + if tgt_type == "grid": + tgt = _grid_cube(n_lons_tgt, n_lats_tgt, lon_bounds, lat_bounds, circular=True) + tgt.data = ma.array(tgt.data) + tgt.data[0, 0] = ma.masked + elif tgt_type == "mesh": + tgt = _gridlike_mesh_cube(n_lons_tgt, n_lats_tgt) + tgt.data = ma.array(tgt.data) + tgt.data[0] = ma.masked + + original_rg = scheme().regridder(src, tgt, use_src_mask=True, use_tgt_mask=True) + filename = tmp_path / "regridder.nc" + save_regridder(original_rg, filename) + loaded_rg = load_regridder(str(filename)) + + assert np.allclose(original_rg.src_mask, loaded_rg.src_mask) + assert np.allclose(original_rg.tgt_mask, loaded_rg.tgt_mask) + + +@pytest.mark.parametrize( + "scheme", + [ESMFAreaWeighted], + ids=["conservative"], +) +def test_generic_regridder_resolution(tmp_path, scheme): + """Test save/load round tripping for regridders in `esmf_regrid.schemes`.""" + n_lons_src = 6 + n_lons_tgt = 3 + n_lats_src = 4 + n_lats_tgt = 2 + lon_bounds = (-180, 180) + lat_bounds = (-90, 90) + src = _grid_cube(n_lons_src, n_lats_src, lon_bounds, lat_bounds, circular=True) + tgt = _grid_cube(n_lons_tgt, n_lats_tgt, lon_bounds, lat_bounds, circular=True) + src_resolution = 3 + tgt_resolution = 4 + + original_rg = scheme().regridder( + src, tgt, src_resolution=src_resolution, tgt_resolution=tgt_resolution + ) + filename = tmp_path / "regridder.nc" + save_regridder(original_rg, filename) + loaded_rg = load_regridder(str(filename)) + + assert loaded_rg.src_resolution == src_resolution + assert loaded_rg.regridder.src.resolution == src_resolution + assert loaded_rg.tgt_resolution == tgt_resolution + assert loaded_rg.regridder.tgt.resolution == tgt_resolution diff --git a/esmf_regrid/tests/unit/schemes/test_ESMFAreaWeightedRegridder.py b/esmf_regrid/tests/unit/schemes/test_ESMFAreaWeightedRegridder.py index c5d77ae9..8d68958e 100644 --- a/esmf_regrid/tests/unit/schemes/test_ESMFAreaWeightedRegridder.py +++ b/esmf_regrid/tests/unit/schemes/test_ESMFAreaWeightedRegridder.py @@ -290,3 +290,28 @@ def test_masks(): weights_src_masked[:, 1:].todense(), weights_unmasked[:, 1:].todense() ) assert np.allclose(weights_tgt_masked[1:].todense(), weights_unmasked[1:].todense()) + + +def test_resolution(): + """ + Test calling of :class:`esmf_regrid.schemes.ESMFAreaWeightedRegridder`. + + Checks that the regridder accepts resolution arguments. + """ + n_lons = 6 + n_lats = 5 + lon_bounds = (-180, 180) + lat_bounds = (-90, 90) + src = _grid_cube(n_lons, n_lats, lon_bounds, lat_bounds, circular=True) + tgt = _grid_cube(n_lons, n_lats, lon_bounds, lat_bounds, circular=True) + + src_resolution = 3 + tgt_resolution = 4 + + regridder = ESMFAreaWeightedRegridder( + src, tgt, src_resolution=src_resolution, tgt_resolution=tgt_resolution + ) + assert regridder.src_resolution == src_resolution + assert regridder.regridder.src.resolution == src_resolution + assert regridder.tgt_resolution == tgt_resolution + assert regridder.regridder.tgt.resolution == tgt_resolution From 5377dd79a4b7cdef8ff1da575027d4ce82b72675 Mon Sep 17 00:00:00 2001 From: "stephen.worsley" Date: Wed, 29 May 2024 16:00:46 +0100 Subject: [PATCH 5/9] generalise existing tests --- .../experimental/io/test_round_tripping.py | 203 +++++++++++++----- 1 file changed, 148 insertions(+), 55 deletions(-) diff --git a/esmf_regrid/tests/unit/experimental/io/test_round_tripping.py b/esmf_regrid/tests/unit/experimental/io/test_round_tripping.py index cea1e95f..f17fd941 100644 --- a/esmf_regrid/tests/unit/experimental/io/test_round_tripping.py +++ b/esmf_regrid/tests/unit/experimental/io/test_round_tripping.py @@ -4,7 +4,13 @@ from numpy import ma import pytest -from esmf_regrid import Constants, ESMFAreaWeighted, ESMFBilinear, ESMFNearest +from esmf_regrid import ( + Constants, + ESMFAreaWeighted, + ESMFAreaWeightedRegridder, + ESMFBilinear, + ESMFNearest, +) from esmf_regrid.experimental.io import load_regridder, save_regridder from esmf_regrid.experimental.unstructured_scheme import ( GridToMeshESMFRegridder, @@ -21,6 +27,7 @@ def _make_grid_to_mesh_regridder( method=Constants.Method.CONSERVATIVE, + regridder=GridToMeshESMFRegridder, resolution=None, grid_dims=1, circular=True, @@ -60,20 +67,21 @@ def _make_grid_to_mesh_regridder( use_src_mask = False use_tgt_mask = False - rg = GridToMeshESMFRegridder( - src, - tgt, - method=method, - mdtol=0.5, - src_resolution=resolution, - use_src_mask=use_src_mask, - use_tgt_mask=use_tgt_mask, - ) + kwargs = { + "mdtol": 0.5, + "src_resolution": resolution, + "use_src_mask": use_src_mask, + "use_tgt_mask": use_tgt_mask, + } + if regridder == GridToMeshESMFRegridder: + kwargs["method"] = method + rg = regridder(src, tgt, **kwargs) return rg, src def _make_mesh_to_grid_regridder( method=Constants.Method.CONSERVATIVE, + regridder=MeshToGridESMFRegridder, resolution=None, grid_dims=1, circular=True, @@ -83,7 +91,10 @@ def _make_mesh_to_grid_regridder( src_lats = 4 tgt_lons = 5 tgt_lats = 6 - lon_bounds = (-180, 180) + if circular: + lon_bounds = (-180, 180) + else: + lon_bounds = (-180, 170) lat_bounds = (-90, 90) if grid_dims == 1: tgt = _grid_cube(tgt_lons, tgt_lats, lon_bounds, lat_bounds, circular=circular) @@ -110,14 +121,18 @@ def _make_mesh_to_grid_regridder( use_src_mask = False use_tgt_mask = False - rg = MeshToGridESMFRegridder( + kwargs = { + "mdtol": 0.5, + "tgt_resolution": resolution, + "use_src_mask": use_src_mask, + "use_tgt_mask": use_tgt_mask, + } + if regridder == MeshToGridESMFRegridder: + kwargs["method"] = method + rg = regridder( src, tgt, - method=method, - mdtol=0.5, - tgt_resolution=resolution, - use_src_mask=use_src_mask, - use_tgt_mask=use_tgt_mask, + **kwargs, ) return rg, src @@ -130,25 +145,33 @@ def _compare_ignoring_var_names(x, y): @pytest.mark.parametrize( - "method", + "method,regridder", [ - Constants.Method.CONSERVATIVE, - Constants.Method.BILINEAR, - Constants.Method.NEAREST, + (Constants.Method.CONSERVATIVE, GridToMeshESMFRegridder), + (Constants.Method.BILINEAR, GridToMeshESMFRegridder), + (Constants.Method.NEAREST, GridToMeshESMFRegridder), + (None, ESMFAreaWeightedRegridder), ], ) -def test_GridToMeshESMFRegridder_round_trip(tmp_path, method): - """Test save/load round tripping for `GridToMeshESMFRegridder`.""" - original_rg, src = _make_grid_to_mesh_regridder(method=method, circular=True) +def test_grid_to_mesh_round_trip(tmp_path, method, regridder): + """Test save/load round tripping for grid to mesh regridding.""" + original_rg, src = _make_grid_to_mesh_regridder( + method=method, regridder=regridder, circular=True + ) filename = tmp_path / "regridder.nc" save_regridder(original_rg, filename) loaded_rg = load_regridder(str(filename)) - assert original_rg.location == loaded_rg.location + if regridder == GridToMeshESMFRegridder: + assert original_rg.location == loaded_rg.location + _compare_ignoring_var_names(original_rg.grid_x, loaded_rg.grid_x) + _compare_ignoring_var_names(original_rg.grid_y, loaded_rg.grid_y) + else: + assert original_rg._tgt.location == loaded_rg._tgt.location + _compare_ignoring_var_names(original_rg._src[0], loaded_rg._src[0]) + _compare_ignoring_var_names(original_rg._src[1], loaded_rg._src[1]) assert original_rg.method == loaded_rg.method assert original_rg.mdtol == loaded_rg.mdtol - _compare_ignoring_var_names(original_rg.grid_x, loaded_rg.grid_x) - _compare_ignoring_var_names(original_rg.grid_y, loaded_rg.grid_y) # TODO: uncomment when iris mesh comparison becomes available. # assert original_rg.mesh == loaded_rg.mesh @@ -188,25 +211,52 @@ def test_GridToMeshESMFRegridder_round_trip(tmp_path, method): original_res_rg.regridder.src.resolution == loaded_res_rg.regridder.src.resolution ) + elif regridder == ESMFAreaWeightedRegridder: + assert original_rg.src_resolution == loaded_rg.src_resolution + original_res_rg, _ = _make_grid_to_mesh_regridder( + regridder=regridder, resolution=8 + ) + res_filename = tmp_path / "regridder_res.nc" + save_regridder(original_res_rg, res_filename) + loaded_res_rg = load_regridder(str(res_filename)) + assert original_res_rg.src_resolution == loaded_res_rg.src_resolution + assert ( + original_res_rg.regridder.src.resolution + == loaded_res_rg.regridder.src.resolution + ) # Ensure grid equality for non-circular coords. - original_nc_rg, _ = _make_grid_to_mesh_regridder(method=method, circular=False) + original_nc_rg, src = _make_grid_to_mesh_regridder( + method=method, regridder=regridder, circular=True + ) nc_filename = tmp_path / "non_circular_regridder.nc" save_regridder(original_nc_rg, nc_filename) loaded_nc_rg = load_regridder(str(nc_filename)) - _compare_ignoring_var_names(original_nc_rg.grid_x, loaded_nc_rg.grid_x) - _compare_ignoring_var_names(original_nc_rg.grid_y, loaded_nc_rg.grid_y) + if regridder == GridToMeshESMFRegridder: + _compare_ignoring_var_names(original_nc_rg.grid_x, loaded_nc_rg.grid_x) + _compare_ignoring_var_names(original_nc_rg.grid_y, loaded_nc_rg.grid_y) + else: + _compare_ignoring_var_names(original_nc_rg._src[0], loaded_nc_rg._src[0]) + _compare_ignoring_var_names(original_nc_rg._src[1], loaded_nc_rg._src[1]) -def test_GridToMeshESMFRegridder_curvilinear_round_trip(tmp_path): - """Test save/load round tripping for `GridToMeshESMFRegridder`.""" - original_rg, src = _make_grid_to_mesh_regridder(grid_dims=2) +@pytest.mark.parametrize( + "regridder", + [GridToMeshESMFRegridder, ESMFAreaWeightedRegridder], +) +def test_grid_to_mesh_curvilinear_round_trip(tmp_path, regridder): + """Test save/load round tripping for grid to mesh regridding.""" + original_rg, src = _make_grid_to_mesh_regridder(regridder=regridder, grid_dims=2) filename = tmp_path / "regridder.nc" save_regridder(original_rg, filename) loaded_rg = load_regridder(str(filename)) - _compare_ignoring_var_names(original_rg.grid_x, loaded_rg.grid_x) - _compare_ignoring_var_names(original_rg.grid_y, loaded_rg.grid_y) + if regridder == GridToMeshESMFRegridder: + _compare_ignoring_var_names(original_rg.grid_x, loaded_rg.grid_x) + _compare_ignoring_var_names(original_rg.grid_y, loaded_rg.grid_y) + else: + _compare_ignoring_var_names(original_rg._src[0], loaded_rg._src[0]) + _compare_ignoring_var_names(original_rg._src[1], loaded_rg._src[1]) # Demonstrate regridding still gives the same results. src_data = ma.arange(np.product(src.data.shape)).reshape(src.data.shape) @@ -220,14 +270,21 @@ def test_GridToMeshESMFRegridder_curvilinear_round_trip(tmp_path): # TODO: parametrize the rest of the tests in this module. +@pytest.mark.parametrize( + "regridder", + ["unstructured", ESMFAreaWeightedRegridder], +) @pytest.mark.parametrize( "rg_maker", [_make_grid_to_mesh_regridder, _make_mesh_to_grid_regridder], ids=["grid_to_mesh", "mesh_to_grid"], ) -def test_MeshESMFRegridder_masked_round_trip(tmp_path, rg_maker): +def test_MeshESMFRegridder_masked_round_trip(tmp_path, rg_maker, regridder): """Test save/load round tripping for the Mesh regridder classes.""" - original_rg, src = rg_maker(masks=True) + if regridder == "unstructured": + original_rg, src = rg_maker(masks=True) + else: + original_rg, src = rg_maker(regridder=regridder, masks=True) filename = tmp_path / "regridder.nc" save_regridder(original_rg, filename) loaded_rg = load_regridder(str(filename)) @@ -245,25 +302,34 @@ def test_MeshESMFRegridder_masked_round_trip(tmp_path, rg_maker): @pytest.mark.parametrize( - "method", + "method,regridder", [ - Constants.Method.CONSERVATIVE, - Constants.Method.BILINEAR, - Constants.Method.NEAREST, + (Constants.Method.CONSERVATIVE, MeshToGridESMFRegridder), + (Constants.Method.BILINEAR, MeshToGridESMFRegridder), + (Constants.Method.NEAREST, MeshToGridESMFRegridder), + (None, ESMFAreaWeightedRegridder), ], ) -def test_MeshToGridESMFRegridder_round_trip(tmp_path, method): - """Test save/load round tripping for `MeshToGridESMFRegridder`.""" - original_rg, src = _make_mesh_to_grid_regridder(method=method, circular=True) +def test_mesh_to_grid_round_trip(tmp_path, method, regridder): + """Test save/load round tripping for mesh to grid regridding.""" + original_rg, src = _make_mesh_to_grid_regridder( + method=method, regridder=regridder, circular=True + ) filename = tmp_path / "regridder.nc" save_regridder(original_rg, filename) loaded_rg = load_regridder(str(filename)) - assert original_rg.location == loaded_rg.location + if regridder == MeshToGridESMFRegridder: + assert original_rg.location == loaded_rg.location + _compare_ignoring_var_names(original_rg.grid_x, loaded_rg.grid_x) + _compare_ignoring_var_names(original_rg.grid_y, loaded_rg.grid_y) + else: + assert original_rg._src.location == loaded_rg._src.location + _compare_ignoring_var_names(original_rg._tgt[0], loaded_rg._tgt[0]) + _compare_ignoring_var_names(original_rg._tgt[1], loaded_rg._tgt[1]) + assert original_rg.method == loaded_rg.method assert original_rg.mdtol == loaded_rg.mdtol - _compare_ignoring_var_names(original_rg.grid_x, loaded_rg.grid_x) - _compare_ignoring_var_names(original_rg.grid_y, loaded_rg.grid_y) # TODO: uncomment when iris mesh comparison becomes available. # assert original_rg.mesh == loaded_rg.mesh @@ -302,25 +368,52 @@ def test_MeshToGridESMFRegridder_round_trip(tmp_path, method): original_res_rg.regridder.tgt.resolution == loaded_res_rg.regridder.tgt.resolution ) + elif regridder == ESMFAreaWeightedRegridder: + assert original_rg.src_resolution == loaded_rg.src_resolution + original_res_rg, _ = _make_mesh_to_grid_regridder( + regridder=regridder, resolution=8 + ) + res_filename = tmp_path / "regridder_res.nc" + save_regridder(original_res_rg, res_filename) + loaded_res_rg = load_regridder(str(res_filename)) + assert original_res_rg.tgt_resolution == loaded_res_rg.tgt_resolution + assert ( + original_res_rg.regridder.tgt.resolution + == loaded_res_rg.regridder.tgt.resolution + ) # Ensure grid equality for non-circular coords. - original_nc_rg, _ = _make_grid_to_mesh_regridder(method=method, circular=False) + original_nc_rg, _ = _make_mesh_to_grid_regridder( + method=method, regridder=regridder, circular=False + ) nc_filename = tmp_path / "non_circular_regridder.nc" save_regridder(original_nc_rg, nc_filename) loaded_nc_rg = load_regridder(str(nc_filename)) - _compare_ignoring_var_names(original_nc_rg.grid_x, loaded_nc_rg.grid_x) - _compare_ignoring_var_names(original_nc_rg.grid_y, loaded_nc_rg.grid_y) + if regridder == MeshToGridESMFRegridder: + _compare_ignoring_var_names(original_nc_rg.grid_x, loaded_nc_rg.grid_x) + _compare_ignoring_var_names(original_nc_rg.grid_y, loaded_nc_rg.grid_y) + else: + _compare_ignoring_var_names(original_nc_rg._tgt[0], loaded_nc_rg._tgt[0]) + _compare_ignoring_var_names(original_nc_rg._tgt[1], loaded_nc_rg._tgt[1]) -def test_MeshToGridESMFRegridder_curvilinear_round_trip(tmp_path): - """Test save/load round tripping for `MeshToGridESMFRegridder`.""" - original_rg, src = _make_mesh_to_grid_regridder(grid_dims=2) +@pytest.mark.parametrize( + "regridder", + [MeshToGridESMFRegridder, ESMFAreaWeightedRegridder], +) +def test_mesh_to_grid_curvilinear_round_trip(tmp_path, regridder): + """Test save/load round tripping for mesh to grid regridding.""" + original_rg, src = _make_mesh_to_grid_regridder(regridder=regridder, grid_dims=2) filename = tmp_path / "regridder.nc" save_regridder(original_rg, filename) loaded_rg = load_regridder(str(filename)) - _compare_ignoring_var_names(original_rg.grid_x, loaded_rg.grid_x) - _compare_ignoring_var_names(original_rg.grid_y, loaded_rg.grid_y) + if regridder == MeshToGridESMFRegridder: + _compare_ignoring_var_names(original_rg.grid_x, loaded_rg.grid_x) + _compare_ignoring_var_names(original_rg.grid_y, loaded_rg.grid_y) + else: + _compare_ignoring_var_names(original_rg._tgt[0], loaded_rg._tgt[0]) + _compare_ignoring_var_names(original_rg._tgt[1], loaded_rg._tgt[1]) # Demonstrate regridding still gives the same results. src_data = ma.arange(np.product(src.data.shape)).reshape(src.data.shape) From 05209672d69270925eacbe4a7772d05288d07adf Mon Sep 17 00:00:00 2001 From: "stephen.worsley" Date: Thu, 30 May 2024 14:27:19 +0100 Subject: [PATCH 6/9] review comments, test for _managed_var_name --- esmf_regrid/experimental/io.py | 13 +++-- .../experimental/io/test_save_regridder.py | 53 ++++++++++++++++++- 2 files changed, 60 insertions(+), 6 deletions(-) diff --git a/esmf_regrid/experimental/io.py b/esmf_regrid/experimental/io.py index c69dc9cf..2827becc 100644 --- a/esmf_regrid/experimental/io.py +++ b/esmf_regrid/experimental/io.py @@ -104,14 +104,17 @@ def save_regridder(rg, filename): """ Save a regridder scheme instance. - Saves either a - :class:`~esmf_regrid.experimental.unstructured_scheme.GridToMeshESMFRegridder` - or a - :class:`~esmf_regrid.experimental.unstructured_scheme.MeshToGridESMFRegridder`. + Saves any of the regridder classes, i.e. + :class:`~esmf_regrid.experimental.unstructured_scheme.GridToMeshESMFRegridder`, + :class:`~esmf_regrid.experimental.unstructured_scheme.MeshToGridESMFRegridder`, + :class:`~esmf_regrid.schemes.ESMFAreaWeightedRegridder`, + :class:`~esmf_regrid.schemes.ESMFBilinearRegridder` or + :class:`~esmf_regrid.schemes.ESMFNearestRegridder`. + . Parameters ---------- - rg : :class:`~esmf_regrid.experimental.unstructured_scheme.GridToMeshESMFRegridder` or :class:`~esmf_regrid.experimental.unstructured_scheme.MeshToGridESMFRegridder` + rg : :class:`~esmf_regrid.schemes._ESMFRegridder` The regridder instance to save. filename : str The file name to save to. diff --git a/esmf_regrid/tests/unit/experimental/io/test_save_regridder.py b/esmf_regrid/tests/unit/experimental/io/test_save_regridder.py index 234fb604..492f554d 100644 --- a/esmf_regrid/tests/unit/experimental/io/test_save_regridder.py +++ b/esmf_regrid/tests/unit/experimental/io/test_save_regridder.py @@ -2,7 +2,11 @@ import pytest -from esmf_regrid.experimental.io import save_regridder +from esmf_regrid.experimental.io import save_regridder, _managed_var_name +from esmf_regrid.schemes import ESMFAreaWeightedRegridder +from esmf_regrid.tests.unit.schemes.test__mesh_to_MeshInfo import ( + _gridlike_mesh_cube, +) def test_invalid_type(tmp_path): @@ -11,3 +15,50 @@ def test_invalid_type(tmp_path): filename = tmp_path / "regridder.nc" with pytest.raises(TypeError): save_regridder(invalid_obj, filename) + + +def test_var_name_preserve(tmp_path): + """Test that `save_regridder` does not change var_ames.""" + lons = 3 + lats = 4 + src = _gridlike_mesh_cube(lons, lats) + tgt = _gridlike_mesh_cube(lons, lats) + + DUMMY_VAR_NAME_SRC = "src_dummy_var" + DUMMY_VAR_NAME_TGT = "tgt_dummy_var" + for coord in src.mesh.coords(): + coord.var_name = DUMMY_VAR_NAME_SRC + for coord in tgt.mesh.coords(): + coord.var_name = DUMMY_VAR_NAME_TGT + + rg = ESMFAreaWeightedRegridder(src, tgt) + filename = tmp_path / "regridder.nc" + save_regridder(rg, filename) + + for coord in src.mesh.coords(): + assert coord.var_name == DUMMY_VAR_NAME_SRC + for coord in tgt.mesh.coords(): + assert coord.var_name == DUMMY_VAR_NAME_TGT + + +def test_managed_var_name(): + """Test that `_managed_var_name` changes var_names.""" + lons = 3 + lats = 4 + src = _gridlike_mesh_cube(lons, lats) + tgt = _gridlike_mesh_cube(lons, lats) + + DUMMY_VAR_NAME_SRC = "src_dummy_var" + DUMMY_VAR_NAME_TGT = "tgt_dummy_var" + for coord in src.mesh.coords(): + coord.var_name = DUMMY_VAR_NAME_SRC + for coord in tgt.mesh.coords(): + coord.var_name = DUMMY_VAR_NAME_TGT + + with _managed_var_name(src, tgt): + for coord in src.mesh.coords(): + print(coord) + assert coord.var_name != DUMMY_VAR_NAME_SRC + for coord in tgt.mesh.coords(): + print(coord) + assert coord.var_name != DUMMY_VAR_NAME_TGT From d812343c91448902b401a336532958fadb3f60bd Mon Sep 17 00:00:00 2001 From: "stephen.worsley" Date: Thu, 30 May 2024 14:40:03 +0100 Subject: [PATCH 7/9] address review comment --- esmf_regrid/experimental/io.py | 7 ++++--- .../tests/unit/experimental/io/test_save_regridder.py | 2 -- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/esmf_regrid/experimental/io.py b/esmf_regrid/experimental/io.py index 2827becc..490d0485 100644 --- a/esmf_regrid/experimental/io.py +++ b/esmf_regrid/experimental/io.py @@ -97,7 +97,6 @@ def _clean_var_names(cube): coord.var_name = None for con in cube.mesh.connectivities(): con.var_name = None - return cube def save_regridder(rg, filename): @@ -293,8 +292,10 @@ def load_regridder(filename): cubes = iris.load(filename) # Extract the source, target and metadata information. - src_cube = _clean_var_names(cubes.extract_cube(SOURCE_NAME)) - tgt_cube = _clean_var_names(cubes.extract_cube(TARGET_NAME)) + src_cube = cubes.extract_cube(SOURCE_NAME) + _clean_var_names(src_cube) + tgt_cube = cubes.extract_cube(TARGET_NAME) + _clean_var_names(tgt_cube) weights_cube = cubes.extract_cube(WEIGHTS_NAME) weight_shape_cube = cubes.extract_cube(WEIGHTS_SHAPE_NAME) diff --git a/esmf_regrid/tests/unit/experimental/io/test_save_regridder.py b/esmf_regrid/tests/unit/experimental/io/test_save_regridder.py index 492f554d..c234e9c8 100644 --- a/esmf_regrid/tests/unit/experimental/io/test_save_regridder.py +++ b/esmf_regrid/tests/unit/experimental/io/test_save_regridder.py @@ -57,8 +57,6 @@ def test_managed_var_name(): with _managed_var_name(src, tgt): for coord in src.mesh.coords(): - print(coord) assert coord.var_name != DUMMY_VAR_NAME_SRC for coord in tgt.mesh.coords(): - print(coord) assert coord.var_name != DUMMY_VAR_NAME_TGT From ab2ee75ee9e92ec0f552dd4dde60414dfcc9d3e2 Mon Sep 17 00:00:00 2001 From: "stephen.worsley" Date: Thu, 30 May 2024 14:54:52 +0100 Subject: [PATCH 8/9] add changelog --- CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 52a60ab8..19c6afed 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,13 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ## [Unreleased] +### Added + +- [PR#357](https://github.com/SciTools-incubator/iris-esmf-regrid/pull/357) + Added support for saving and loading of `ESMFAreaWeighted`, `ESMFBilinear` + and `ESMFNearest` regridders. + [@stephenworsley](https://github.com/stephenworsley) + ### Changed - [PR#361](https://github.com/SciTools-incubator/iris-esmf-regrid/pull/361) From f07780916cf02c13ccf52f47b428977fb553c6b5 Mon Sep 17 00:00:00 2001 From: "stephen.worsley" Date: Thu, 30 May 2024 15:05:46 +0100 Subject: [PATCH 9/9] flake 8 --- esmf_regrid/tests/unit/experimental/io/test_save_regridder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/esmf_regrid/tests/unit/experimental/io/test_save_regridder.py b/esmf_regrid/tests/unit/experimental/io/test_save_regridder.py index c234e9c8..68420790 100644 --- a/esmf_regrid/tests/unit/experimental/io/test_save_regridder.py +++ b/esmf_regrid/tests/unit/experimental/io/test_save_regridder.py @@ -2,7 +2,7 @@ import pytest -from esmf_regrid.experimental.io import save_regridder, _managed_var_name +from esmf_regrid.experimental.io import _managed_var_name, save_regridder from esmf_regrid.schemes import ESMFAreaWeightedRegridder from esmf_regrid.tests.unit.schemes.test__mesh_to_MeshInfo import ( _gridlike_mesh_cube,