Skip to content

Commit

Permalink
Ensure dtype is preserved after regridding (#239)
Browse files Browse the repository at this point in the history
* ensure dtype is preserved after regridding

* add to changelog

* convert dtype when source is int type

* revert dtype handling, fix dask dtype handling

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix dtype

* add tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* address review comments

* flake 8 fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
stephenworsley and pre-commit-ci[bot] authored May 29, 2024
1 parent a50b2b8 commit cf86c1d
Show file tree
Hide file tree
Showing 10 changed files with 147 additions and 5 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
benchmarks.
[@stephenworsley](https://github.com/stephenworsley)

### Fixed
- [PR#239](https://github.com/SciTools-incubator/iris-esmf-regrid/pull/239)
Ensured dtype is preserved by regridding.
[@stephenworsley](https://github.com/stephenworsley)

## [0.9] - 2023-11-03

### Added
Expand Down
9 changes: 8 additions & 1 deletion esmf_regrid/esmf_regridder.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,12 @@ def __init__(
self.esmf_version = None
self.weight_matrix = precomputed_weights

def _out_dtype(self, in_dtype):
"""Return the expected output dtype for a given input dtype."""
weight_dtype = self.weight_matrix.dtype
out_dtype = (np.ones(1, dtype=in_dtype) * np.ones(1, dtype=weight_dtype)).dtype
return out_dtype

def regrid(self, src_array, norm_type=Constants.NormType.FRACAREA, mdtol=1):
"""
Perform regridding on an array of data.
Expand Down Expand Up @@ -175,12 +181,13 @@ def regrid(self, src_array, norm_type=Constants.NormType.FRACAREA, mdtol=1):
extra_size = max(1, np.prod(extra_shape))
src_inverted_mask = self.src._array_to_matrix(~ma.getmaskarray(src_array))
weight_sums = self.weight_matrix @ src_inverted_mask
out_dtype = self._out_dtype(src_array.dtype)
# Set the minimum mdtol to be slightly higher than 0 to account for rounding
# errors.
mdtol = max(mdtol, 1e-8)
tgt_mask = weight_sums > 1 - mdtol
masked_weight_sums = weight_sums * tgt_mask
normalisations = np.ones([self.tgt.size, extra_size])
normalisations = np.ones([self.tgt.size, extra_size], dtype=out_dtype)
if norm_type == Constants.NormType.FRACAREA:
normalisations[tgt_mask] /= masked_weight_sums[tgt_mask]
elif norm_type == Constants.NormType.DSTAREA:
Expand Down
22 changes: 20 additions & 2 deletions esmf_regrid/schemes.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,9 @@ def _regrid_along_dims(data, regridder, dims, num_out_dims, mdtol):
return result


def _map_complete_blocks(src, func, active_dims, out_sizes, *args, **kwargs):
def _map_complete_blocks(
src, func, active_dims, out_sizes, *args, dtype=None, **kwargs
):
"""
Apply a function to complete blocks.
Expand All @@ -299,6 +301,8 @@ def _map_complete_blocks(src, func, active_dims, out_sizes, *args, **kwargs):
Dimensions that cannot be chunked.
out_sizes : tuple of int
Output size of dimensions that cannot be chunked.
dtype : type, optional
Type of the output array, if not given, the dtype of src is used.
Returns
-------
Expand All @@ -311,6 +315,8 @@ def _map_complete_blocks(src, func, active_dims, out_sizes, *args, **kwargs):
return func(src.data, *args, **kwargs)

data = src.lazy_data()
if dtype is None:
dtype = data.dtype

# Ensure dims are not chunked
in_chunks = list(data.chunks)
Expand Down Expand Up @@ -373,7 +379,7 @@ def _map_complete_blocks(src, func, active_dims, out_sizes, *args, **kwargs):
chunks=out_chunks,
drop_axis=dropped_dims,
new_axis=new_axis,
dtype=src.dtype,
dtype=dtype,
**kwargs,
)

Expand Down Expand Up @@ -557,6 +563,8 @@ def _regrid_rectilinear_to_rectilinear__perform(src_cube, regrid_info, mdtol):
grid_x, grid_y = regrid_info.target
regridder = regrid_info.regridder

out_dtype = regridder._out_dtype(src_cube.dtype)

# Apply regrid to all the chunks of src_cube, ensuring first that all
# chunks cover the entire horizontal plane (otherwise they would break
# the regrid function).
Expand All @@ -574,6 +582,7 @@ def _regrid_rectilinear_to_rectilinear__perform(src_cube, regrid_info, mdtol):
dims=[grid_x_dim, grid_y_dim],
num_out_dims=2,
mdtol=mdtol,
dtype=out_dtype,
)

new_cube = _create_cube(
Expand Down Expand Up @@ -636,6 +645,8 @@ def _regrid_unstructured_to_rectilinear__perform(src_cube, regrid_info, mdtol):
grid_x, grid_y = regrid_info.target
regridder = regrid_info.regridder

out_dtype = regridder._out_dtype(src_cube.dtype)

# Apply regrid to all the chunks of src_cube, ensuring first that all
# chunks cover the entire horizontal plane (otherwise they would break
# the regrid function).
Expand All @@ -653,6 +664,7 @@ def _regrid_unstructured_to_rectilinear__perform(src_cube, regrid_info, mdtol):
dims=[mesh_dim],
num_out_dims=2,
mdtol=mdtol,
dtype=out_dtype,
)

new_cube = _create_cube(
Expand Down Expand Up @@ -739,6 +751,8 @@ def _regrid_rectilinear_to_unstructured__perform(src_cube, regrid_info, mdtol):
else:
raise NotImplementedError(f"Unrecognised location {location}.")

out_dtype = regridder._out_dtype(src_cube.dtype)

# Apply regrid to all the chunks of src_cube, ensuring first that all
# chunks cover the entire horizontal plane (otherwise they would break
# the regrid function).
Expand All @@ -751,6 +765,7 @@ def _regrid_rectilinear_to_unstructured__perform(src_cube, regrid_info, mdtol):
dims=[grid_x_dim, grid_y_dim],
num_out_dims=1,
mdtol=mdtol,
dtype=out_dtype,
)

new_cube = _create_cube(
Expand Down Expand Up @@ -823,6 +838,8 @@ def _regrid_unstructured_to_unstructured__perform(src_cube, regrid_info, mdtol):
mesh, location = regrid_info.target
regridder = regrid_info.regridder

out_dtype = regridder._out_dtype(src_cube.dtype)

if location == "face":
face_node = mesh.face_node_connectivity
chunk_shape = (face_node.shape[face_node.location_axis],)
Expand All @@ -840,6 +857,7 @@ def _regrid_unstructured_to_unstructured__perform(src_cube, regrid_info, mdtol):
dims=[mesh_dim],
num_out_dims=1,
mdtol=mdtol,
dtype=out_dtype,
)

new_cube = _create_cube(
Expand Down
22 changes: 22 additions & 0 deletions esmf_regrid/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""Common testing infrastructure."""

import pytest


@pytest.fixture(params=["float32", "float64"])
def in_dtype(request):
"""Fixture for controlling dtype."""
return request.param


@pytest.fixture(
params=[
("grid", "grid"),
("grid", "mesh"),
("mesh", "grid"),
("mesh", "mesh"),
]
)
def src_tgt_types(request):
"""Fixture for controlling type of source and target."""
return request.param
26 changes: 26 additions & 0 deletions esmf_regrid/tests/unit/esmf_regridder/test_Regridder.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,3 +226,29 @@ def _get_points(bounds):
(weights_dict["weights"], (weights_dict["row_dst"], weights_dict["col_src"]))
)
assert np.allclose(result.toarray(), expected_weights.toarray())


def test_Regridder_dtype_handling():
"""
Basic test for :meth:`~esmf_regrid.esmf_regridder.Regridder.regrid`.
Tests that dtype is handled as expected.
"""
lon, lat, lon_bounds, lat_bounds = make_grid_args(2, 3)
src_grid = GridInfo(lon, lat, lon_bounds, lat_bounds)

lon, lat, lon_bounds, lat_bounds = make_grid_args(3, 2)
tgt_grid = GridInfo(lon, lat, lon_bounds, lat_bounds)

# Set up the regridder with precomputed weights.
rg_64 = Regridder(src_grid, tgt_grid, precomputed_weights=_expected_weights())
weights_32 = _expected_weights().astype(np.float32)
rg_32 = Regridder(src_grid, tgt_grid, precomputed_weights=weights_32)

src_32 = np.ones([3, 2], dtype=np.float32)
src_64 = np.ones([3, 2], dtype=np.float64)

assert rg_64.regrid(src_64).dtype == np.float64
assert rg_64.regrid(src_32).dtype == np.float64
assert rg_32.regrid(src_64).dtype == np.float64
assert rg_32.regrid(src_32).dtype == np.float32
37 changes: 37 additions & 0 deletions esmf_regrid/tests/unit/schemes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Unit tests for `esmf_regrid.schemes`."""

import dask.array as da
from iris.coord_systems import OSGB
import numpy as np
from numpy import ma
Expand Down Expand Up @@ -215,3 +216,39 @@ def _test_non_degree_crs(scheme):

# Check that the number of masked points is as expected.
assert (1 - result.data.mask).sum() == expected_unmasked


def _test_dtype_handling(scheme, src_type, tgt_type, in_dtype):
"""Test regridding scheme handles dtype as expected."""
n_lons_src = 6
n_lons_tgt = 3
n_lats_src = 4
n_lats_tgt = 2
lon_bounds = (-180, 180)
lat_bounds = (-90, 90)
if in_dtype == "float32":
dtype = np.float32
elif in_dtype == "float64":
dtype = np.float64

if src_type == "grid":
src = _grid_cube(n_lons_src, n_lats_src, lon_bounds, lat_bounds, circular=True)
src_data = np.zeros([n_lats_src, n_lons_src], dtype=dtype)
src.data = da.array(src_data)
elif src_type == "mesh":
src = _gridlike_mesh_cube(n_lons_src, n_lats_src)
src_data = np.zeros([n_lats_src * n_lons_src], dtype=dtype)
src.data = da.array(src_data)
if tgt_type == "grid":
tgt = _grid_cube(n_lons_tgt, n_lats_tgt, lon_bounds, lat_bounds, circular=True)
elif tgt_type == "mesh":
tgt = _gridlike_mesh_cube(n_lons_tgt, n_lats_tgt)

result = src.regrid(tgt, scheme())

expected_dtype = np.float64

assert result.has_lazy_data()

assert result.lazy_data().dtype == expected_dtype
assert result.data.dtype == expected_dtype
7 changes: 7 additions & 0 deletions esmf_regrid/tests/unit/schemes/test_ESMFAreaWeighted.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from esmf_regrid.schemes import ESMFAreaWeighted
from esmf_regrid.tests.unit.schemes.__init__ import (
_test_cube_regrid,
_test_dtype_handling,
_test_invalid_mdtol,
_test_mask_from_init,
_test_mask_from_regridder,
Expand Down Expand Up @@ -74,3 +75,9 @@ def test_invalid_tgt_location():
def test_non_degree_crs():
"""Test for coordinates with non-degree units."""
_test_non_degree_crs(ESMFAreaWeighted)


def test_dtype_handling(src_tgt_types, in_dtype):
"""Test regridding scheme handles dtype as expected."""
src_type, tgt_type = src_tgt_types
_test_dtype_handling(ESMFAreaWeighted, src_type, tgt_type, in_dtype)
7 changes: 7 additions & 0 deletions esmf_regrid/tests/unit/schemes/test_ESMFBilinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from esmf_regrid.schemes import ESMFBilinear
from esmf_regrid.tests.unit.schemes.__init__ import (
_test_cube_regrid,
_test_dtype_handling,
_test_invalid_mdtol,
_test_mask_from_init,
_test_mask_from_regridder,
Expand Down Expand Up @@ -63,3 +64,9 @@ def test_mask_from_regridder(mask_keyword):
def test_non_degree_crs():
"""Test for coordinates with non-degree units."""
_test_non_degree_crs(ESMFBilinear)


def test_dtype_handling(src_tgt_types, in_dtype):
"""Test regridding scheme handles dtype as expected."""
src_type, tgt_type = src_tgt_types
_test_dtype_handling(ESMFBilinear, src_type, tgt_type, in_dtype)
7 changes: 7 additions & 0 deletions esmf_regrid/tests/unit/schemes/test_ESMFNearest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from esmf_regrid.schemes import ESMFNearest
from esmf_regrid.tests.unit.schemes.__init__ import (
_test_dtype_handling,
_test_mask_from_init,
_test_mask_from_regridder,
_test_non_degree_crs,
Expand Down Expand Up @@ -106,3 +107,9 @@ def test_mask_from_regridder(mask_keyword):
def test_non_degree_crs():
"""Test for coordinates with non-degree units."""
_test_non_degree_crs(ESMFNearest)


def test_dtype_handling(src_tgt_types, in_dtype):
"""Test regridding scheme handles dtype as expected."""
src_type, tgt_type = src_tgt_types
_test_dtype_handling(ESMFNearest, src_type, tgt_type, in_dtype)
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,9 @@ def test_laziness(src_transposed, tgt_transposed):
lat_bounds = (-90, 90)

grid = _grid_cube(n_lons, n_lats, lon_bounds, lat_bounds, circular=True)
src_data = np.arange(n_lats * n_lons * h).reshape([n_lats, n_lons, h])
src_data = np.arange(n_lats * n_lons * h, dtype=np.float32).reshape(
[n_lats, n_lons, h]
)
src_data = da.from_array(src_data, chunks=[3, 5, 1])
src = Cube(src_data)
src.add_dim_coord(grid.coord("latitude"), 0)
Expand All @@ -185,6 +187,8 @@ def test_laziness(src_transposed, tgt_transposed):
assert src.has_lazy_data()
result = regrid_rectilinear_to_rectilinear(src, tgt)
assert result.has_lazy_data()
assert result.lazy_data().dtype == np.float64
assert result.data.dtype == np.float64
assert np.allclose(result.data, src_data)


Expand Down Expand Up @@ -227,7 +231,7 @@ def test_laziness_curvilinear(src_transposed, tgt_transposed):
extra = AuxCoord(np.arange(e), long_name="extra dim")

src_data = np.empty([h, src_lats, t, src_lons, e])
src_data[:] = np.arange(t * h * e).reshape([h, t, e])[
src_data[:] = np.arange(t * h * e, dtype=np.float32).reshape([h, t, e])[
:, np.newaxis, :, np.newaxis, :
]
src_data_lazy = da.array(src_data)
Expand All @@ -253,6 +257,8 @@ def test_laziness_curvilinear(src_transposed, tgt_transposed):
result_lazy = regrid_rectilinear_to_rectilinear(src_cube_lazy, tgt_grid)

assert result_lazy.has_lazy_data()
assert result.lazy_data().dtype == np.float64
assert result.data.dtype == np.float64

assert result_lazy == result

Expand Down

0 comments on commit cf86c1d

Please sign in to comment.