Skip to content

Commit

Permalink
extend regridder saving/loading
Browse files Browse the repository at this point in the history
  • Loading branch information
stephenworsley committed May 1, 2024
1 parent 01369d1 commit e5ee8a3
Show file tree
Hide file tree
Showing 3 changed files with 200 additions and 41 deletions.
160 changes: 132 additions & 28 deletions esmf_regrid/experimental/io.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Provides load/save functions for regridders."""

from contextlib import contextmanager

import iris
from iris.coords import AuxCoord
from iris.cube import Cube, CubeList
Expand All @@ -13,9 +15,19 @@
GridToMeshESMFRegridder,
MeshToGridESMFRegridder,
)
from esmf_regrid.schemes import (
ESMFAreaWeightedRegridder,
ESMFBilinearRegridder,
ESMFNearestRegridder,
MeshRecord,
GridRecord,
)


SUPPORTED_REGRIDDERS = [
ESMFAreaWeightedRegridder,
ESMFBilinearRegridder,
ESMFNearestRegridder,
GridToMeshESMFRegridder,
MeshToGridESMFRegridder,
]
Expand All @@ -34,6 +46,8 @@
MDTOL = "mdtol"
METHOD = "method"
RESOLUTION = "resolution"
SOURCE_RESOLUTION = "src_resolution"
TARGET_RESOLUTION = "tgt_resolution"


def _add_mask_to_cube(mask, cube, name):
Expand All @@ -43,6 +57,49 @@ def _add_mask_to_cube(mask, cube, name):
cube.add_aux_coord(mask_coord, list(range(cube.ndim)))


@contextmanager
def managed_var_name(src_cube, tgt_cube):
src_coord_names = []
src_mesh_coords = []
if src_cube.mesh is not None:
src_mesh = src_cube.mesh
src_mesh_coords = src_mesh.coords()
for coord in src_mesh_coords:
src_coord_names.append(coord.var_name)
tgt_coord_names = []
tgt_mesh_coords = []
if tgt_cube.mesh is not None:
tgt_mesh = tgt_cube.mesh
tgt_mesh_coords = tgt_mesh.coords()
for coord in tgt_mesh_coords:
tgt_coord_names.append(coord.var_name)

try:
for coord in src_mesh_coords:
coord.var_name = "_".join([SOURCE_NAME, "mesh", coord.name()])
for coord in tgt_mesh_coords:
coord.var_name = "_".join([TARGET_NAME, "mesh", coord.name()])
yield None
finally:
for coord, var_name in zip(src_mesh_coords, src_coord_names):
coord.var_name = var_name
for coord, var_name in zip(tgt_mesh_coords, tgt_coord_names):
coord.var_name = var_name


def _clean_var_names(cube):
cube.var_name = None
for coord in cube.coords():
coord.var_name = None
if cube.mesh is not None:
cube.mesh.var_name = None
for coord in cube.mesh.coords():
coord.var_name = None
for con in cube.mesh.connectivities():
con.var_name = None
return cube


def save_regridder(rg, filename):
"""
Save a regridder scheme instance.
Expand Down Expand Up @@ -76,28 +133,52 @@ def _standard_grid_cube(grid, name):
cube.add_aux_coord(grid[1], [0, 1])
return cube

if regridder_type == "GridToMeshESMFRegridder":
def _standard_mesh_cube(mesh, location, name):
mesh_coords = mesh.to_MeshCoords(location)
data = np.zeros(mesh_coords[0].points.shape[0])
cube = Cube(data, var_name=name, long_name=name)
for coord in mesh_coords:
cube.add_aux_coord(coord, 0)
return cube

if regridder_type in [
"ESMFAreaWeightedRegridder",
"ESMFBilinearRegridder",
"ESMFNearestRegridder",
]:
src_grid = rg._src
if isinstance(src_grid, GridRecord):
src_cube = _standard_grid_cube(src_grid, SOURCE_NAME)
elif isinstance(src_grid, MeshRecord):
src_mesh, src_location = src_grid
src_cube = _standard_mesh_cube(src_mesh, src_location, SOURCE_NAME)
else:
raise ValueError("Improper type for `rg._src`.")
_add_mask_to_cube(rg.src_mask, src_cube, SOURCE_MASK_NAME)

tgt_grid = rg._tgt
if isinstance(tgt_grid, GridRecord):
tgt_cube = _standard_grid_cube(tgt_grid, TARGET_NAME)
elif isinstance(tgt_grid, MeshRecord):
tgt_mesh, tgt_location = tgt_grid
tgt_cube = _standard_mesh_cube(tgt_mesh, tgt_location, TARGET_NAME)
else:
raise ValueError("Improper type for `rg._tgt`.")
_add_mask_to_cube(rg.tgt_mask, tgt_cube, TARGET_MASK_NAME)
elif regridder_type == "GridToMeshESMFRegridder":
src_grid = (rg.grid_y, rg.grid_x)
src_cube = _standard_grid_cube(src_grid, SOURCE_NAME)
_add_mask_to_cube(rg.src_mask, src_cube, SOURCE_MASK_NAME)

tgt_mesh = rg.mesh
tgt_location = rg.location
tgt_mesh_coords = tgt_mesh.to_MeshCoords(tgt_location)
tgt_data = np.zeros(tgt_mesh_coords[0].points.shape[0])
tgt_cube = Cube(tgt_data, var_name=TARGET_NAME, long_name=TARGET_NAME)
for coord in tgt_mesh_coords:
tgt_cube.add_aux_coord(coord, 0)
tgt_cube = _standard_mesh_cube(tgt_mesh, tgt_location, TARGET_NAME)
_add_mask_to_cube(rg.tgt_mask, tgt_cube, TARGET_MASK_NAME)

elif regridder_type == "MeshToGridESMFRegridder":
src_mesh = rg.mesh
src_location = rg.location
src_mesh_coords = src_mesh.to_MeshCoords(src_location)
src_data = np.zeros(src_mesh_coords[0].points.shape[0])
src_cube = Cube(src_data, var_name=SOURCE_NAME, long_name=SOURCE_NAME)
for coord in src_mesh_coords:
src_cube.add_aux_coord(coord, 0)
src_cube = _standard_mesh_cube(src_mesh, src_location, SOURCE_NAME)
_add_mask_to_cube(rg.src_mask, src_cube, SOURCE_MASK_NAME)

tgt_grid = (rg.grid_y, rg.grid_x)
Expand All @@ -112,7 +193,18 @@ def _standard_grid_cube(grid, name):

method = str(check_method(rg.method).name)

resolution = rg.resolution
if regridder_type in ["GridToMeshESMFRegridder", "MeshToGridESMFRegridder"]:
resolution = rg.resolution
src_resolution = None
tgt_resolution = None
elif regridder_type == "ESMFAreaWeightedRegridder":
resolution = None
src_resolution = rg.src_resolution
tgt_resolution = rg.tgt_resolution
else:
resolution = None
src_resolution = None
tgt_resolution = None

weight_matrix = rg.regridder.weight_matrix
reformatted_weight_matrix = scipy.sparse.coo_matrix(weight_matrix)
Expand Down Expand Up @@ -141,6 +233,10 @@ def _standard_grid_cube(grid, name):
}
if resolution is not None:
attributes[RESOLUTION] = resolution
if src_resolution is not None:
attributes[SOURCE_RESOLUTION] = src_resolution
if tgt_resolution is not None:
attributes[TARGET_RESOLUTION] = tgt_resolution

weights_cube = Cube(weight_data, var_name=WEIGHTS_NAME, long_name=WEIGHTS_NAME)
row_coord = AuxCoord(
Expand All @@ -158,17 +254,14 @@ def _standard_grid_cube(grid, name):
long_name=WEIGHTS_SHAPE_NAME,
)

# Avoid saving bug by placing the mesh cube second.
# TODO: simplify this when this bug is fixed in iris.
if regridder_type == "GridToMeshESMFRegridder":
# Save cubes while ensuring var_names do not conflict for the sake of consistency.
with managed_var_name(src_cube, tgt_cube):
cube_list = CubeList([src_cube, tgt_cube, weights_cube, weight_shape_cube])
elif regridder_type == "MeshToGridESMFRegridder":
cube_list = CubeList([tgt_cube, src_cube, weights_cube, weight_shape_cube])

for cube in cube_list:
cube.attributes = attributes
for cube in cube_list:
cube.attributes = attributes

iris.fileformats.netcdf.save(cube_list, filename)
iris.fileformats.netcdf.save(cube_list, filename)


def load_regridder(filename):
Expand All @@ -193,8 +286,8 @@ def load_regridder(filename):
cubes = iris.load(filename)

# Extract the source, target and metadata information.
src_cube = cubes.extract_cube(SOURCE_NAME)
tgt_cube = cubes.extract_cube(TARGET_NAME)
src_cube = _clean_var_names(cubes.extract_cube(SOURCE_NAME))
tgt_cube = _clean_var_names(cubes.extract_cube(TARGET_NAME))
weights_cube = cubes.extract_cube(WEIGHTS_NAME)
weight_shape_cube = cubes.extract_cube(WEIGHTS_SHAPE_NAME)

Expand All @@ -210,8 +303,12 @@ def load_regridder(filename):
)

resolution = weights_cube.attributes.get(RESOLUTION, None)
src_resolution = weights_cube.attributes.get(SOURCE_RESOLUTION, None)
tgt_resolution = weights_cube.attributes.get(TARGET_RESOLUTION, None)
if resolution is not None:
resolution = int(resolution)
src_resolution = int(src_resolution)
tgt_resolution = int(tgt_resolution)

# Reconstruct the weight matrix.
weight_data = weights_cube.data
Expand All @@ -234,18 +331,25 @@ def load_regridder(filename):
use_tgt_mask = False

if scheme is GridToMeshESMFRegridder:
resolution_keyword = "src_resolution"
resolution_keyword = SOURCE_RESOLUTION
kwargs = {resolution_keyword: resolution, "method": method, "mdtol": mdtol}
elif scheme is MeshToGridESMFRegridder:
resolution_keyword = "tgt_resolution"
resolution_keyword = TARGET_RESOLUTION
kwargs = {resolution_keyword: resolution, "method": method, "mdtol": mdtol}
elif scheme is ESMFAreaWeightedRegridder:
kwargs = {
SOURCE_RESOLUTION: src_resolution,
TARGET_RESOLUTION: tgt_resolution,
"mdtol": mdtol,
}
elif scheme is ESMFBilinearRegridder:
kwargs = {"mdtol": mdtol}
else:
raise NotImplementedError
kwargs = {resolution_keyword: resolution}
kwargs = {}

regridder = scheme(
src_cube,
tgt_cube,
mdtol=mdtol,
method=method,
precomputed_weights=weight_matrix,
use_src_mask=use_src_mask,
use_tgt_mask=use_tgt_mask,
Expand Down
2 changes: 2 additions & 0 deletions esmf_regrid/schemes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1465,8 +1465,10 @@ def __init__(
if tgt_location is not "face".
"""
kwargs = dict()
self.src_resolution = src_resolution
if src_resolution is not None:
kwargs["src_resolution"] = src_resolution
self.tgt_resolution = tgt_resolution
if tgt_resolution is not None:
kwargs["tgt_resolution"] = tgt_resolution
if tgt_location is not None and tgt_location != "face":
Expand Down
Loading

0 comments on commit e5ee8a3

Please sign in to comment.