Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lazy rectilinear interpolator #6084

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
134 changes: 81 additions & 53 deletions lib/iris/analysis/_interpolation.py
trexfeathers marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried writing a benchmark to demonstrate the benefits of this but everything either stayed the same or got slower.

Could you share the results you have seen, especially if it's something I/we could turn into a benchmark? Thanks.

Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from numpy.lib.stride_tricks import as_strided
import numpy.ma as ma

from iris._lazy_data import map_complete_blocks
from iris.analysis._scipy_interpolate import _RegularGridInterpolator
fnattino marked this conversation as resolved.
Show resolved Hide resolved
from iris.coords import AuxCoord, DimCoord
import iris.util

Expand Down Expand Up @@ -163,6 +165,15 @@ def snapshot_grid(cube):
return x.copy(), y.copy()


def _interpolated_dtype(dtype, method):
    """Return the minimum base dtype needed by the underlying interpolator.

    Parameters
    ----------
    dtype :
        The dtype of the data to be interpolated.
    method : str
        The interpolation method. For ``"nearest"`` the input dtype is
        returned unchanged; for any other method it is promoted against
        ``_DEFAULT_DTYPE`` using :func:`numpy.result_type`.

    Returns
    -------
    The dtype to use for the interpolation.

    """
    # Nearest-neighbour needs no promotion: hand the input dtype straight back.
    if method == "nearest":
        return dtype
    return np.result_type(_DEFAULT_DTYPE, dtype)
Comment on lines +167 to +173
Copy link
Contributor

@trexfeathers trexfeathers Sep 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this needs to stay (see my other comment about args=[self] - #6084 (comment)), then I'd be interested in us unifying this function with RectilinearInterpolator._interpolated_dtype().

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean to put this back as a (static)method of the RectilinearInterpolator? Or to merge the body of this function with RectilinearInterpolator._interpolate?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean find any appropriate way for there to be only 1 _interpolated_dtype() in this file, which can be used by both:

  • RectilinearInterpolator._interpolate()
  • RectilinearInterpolator._points()

Up to you how this is achieved.



class RectilinearInterpolator:
"""Provide support for performing nearest-neighbour or linear interpolation.

Expand Down Expand Up @@ -200,13 +211,8 @@ def __init__(self, src_cube, coords, method, extrapolation_mode):
set to NaN.

"""
# Trigger any deferred loading of the source cube's data and snapshot
# its state to ensure that the interpolator is impervious to external
# changes to the original source cube. The data is loaded to prevent
# the snapshot having lazy data, avoiding the potential for the
# same data to be loaded again and again.
if src_cube.has_lazy_data():
src_cube.data
# Snapshot the cube state to ensure that the interpolator is impervious
# to external changes to the original source cube.
self._src_cube = src_cube.copy()
# Coordinates defining the dimensions to be interpolated.
self._src_coords = [self._src_cube.coord(coord) for coord in coords]
Expand Down Expand Up @@ -277,17 +283,27 @@ def _account_for_inverted(self, data):
data = data[tuple(dim_slices)]
return data

def _interpolate(self, data, interp_points):
@staticmethod
def _interpolate(
data,
src_points,
interp_points,
interp_shape,
method="linear",
extrapolation_mode="nanmask",
):
"""Interpolate a data array over N dimensions.

Create and cache the underlying interpolator instance before invoking
it to perform interpolation over the data at the given coordinate point
values.
Create the interpolator instance before invoking it to perform
interpolation over the data at the given coordinate point values.

Parameters
----------
data : ndarray
A data array, to be interpolated in its first 'N' dimensions.
src_points :
The point values defining the dimensions to be interpolated.
(len(src_points) should be N).
interp_points : ndarray
An array of interpolation coordinate values.
Its shape is (..., N) where N is the number of interpolation
Expand All @@ -296,44 +312,51 @@ def _interpolate(self, data, interp_points):
coordinate, which is mapped to the i'th data dimension.
The other (leading) dimensions index over the different required
sample points.
interp_shape :
The shape of the interpolated array in its first 'N' dimensions
(len(interp_shape) should be N).
method: str
Interpolation method (see :class:`iris.analysis._interpolation.RectilinearInterpolator`)
extrapolation_mode: str
Extrapolation mode (see :class:`iris.analysis._interpolation.RectilinearInterpolator`)

Returns
-------
:class:`np.ndarray`.
Its shape is "points_shape + extra_shape",
Its shape is "interp_shape + extra_shape",
where "extra_shape" is the remaining non-interpolated dimensions of
the data array (i.e. 'data.shape[N:]'), and "points_shape" is the
leading dimensions of interp_points,
(i.e. 'interp_points.shape[:-1]').

the data array (i.e. 'data.shape[N:]').
"""
from iris.analysis._scipy_interpolate import _RegularGridInterpolator

dtype = self._interpolated_dtype(data.dtype)
dtype = _interpolated_dtype(data.dtype, method)
if data.dtype != dtype:
# Perform dtype promotion.
data = data.astype(dtype)

mode = EXTRAPOLATION_MODES[self._mode]
if self._interpolator is None:
# Cache the interpolator instance.
# NB. The constructor of the _RegularGridInterpolator class does
# some unnecessary checks on the fill_value parameter,
# so we set it afterwards instead. Sneaky. ;-)
self._interpolator = _RegularGridInterpolator(
self._src_points,
data,
method=self.method,
bounds_error=mode.bounds_error,
fill_value=None,
)
else:
self._interpolator.values = data
trexfeathers marked this conversation as resolved.
Show resolved Hide resolved
# Determine the shape of the interpolated result.
ndims_interp = len(interp_shape)
extra_shape = data.shape[ndims_interp:]
final_shape = [*interp_shape, *extra_shape]
trexfeathers marked this conversation as resolved.
Show resolved Hide resolved

mode = EXTRAPOLATION_MODES[extrapolation_mode]
_data = np.ma.getdata(data)
trexfeathers marked this conversation as resolved.
Show resolved Hide resolved
# NB. The constructor of the _RegularGridInterpolator class does
# some unnecessary checks on the fill_value parameter,
# so we set it afterwards instead. Sneaky. ;-)
interpolator = _RegularGridInterpolator(
src_points,
_data,
method=method,
bounds_error=mode.bounds_error,
fill_value=None,
)
interpolator.fill_value = mode.fill_value
result = interpolator(interp_points)

# We may be re-using a cached interpolator, so ensure the fill
# value is set appropriately for extrapolating data values.
self._interpolator.fill_value = mode.fill_value
result = self._interpolator(interp_points)
# The interpolated result now has shape "points_shape + extra_shape"
# where "points_shape" is the leading dimension of "interp_points"
trexfeathers marked this conversation as resolved.
Show resolved Hide resolved
# (i.e. 'interp_points.shape[:-1]'). We reshape it to match the shape
trexfeathers marked this conversation as resolved.
Show resolved Hide resolved
# of the interpolated dimensions.
result = result.reshape(final_shape)
trexfeathers marked this conversation as resolved.
Show resolved Hide resolved

if result.dtype != data.dtype:
# Cast the data dtype to be as expected. Note that, the dtype
Expand All @@ -346,13 +369,11 @@ def _interpolate(self, data, interp_points):
# `data` is not a masked array.
src_mask = np.ma.getmaskarray(data)
# Switch the extrapolation to work with mask values.
self._interpolator.fill_value = mode.mask_fill_value
self._interpolator.values = src_mask
mask_fraction = self._interpolator(interp_points)
interpolator.fill_value = mode.mask_fill_value
interpolator.values = src_mask
mask_fraction = interpolator(interp_points)
new_mask = mask_fraction > 0
if ma.isMaskedArray(data) or np.any(new_mask):
result = np.ma.MaskedArray(result, new_mask)

result = np.ma.MaskedArray(result, new_mask)
fnattino marked this conversation as resolved.
Show resolved Hide resolved
return result

def _resample_coord(self, sample_points, coord, coord_dims):
Expand Down Expand Up @@ -530,7 +551,7 @@ def _points(self, sample_points, data, data_dims=None):
_, src_order = zip(*sorted(dmap.items(), key=operator.itemgetter(0)))

# Prepare the sample points for interpolation and calculate the
# shape of the interpolated result.
# shape of the interpolated dimensions.
interp_points = []
interp_shape = []
for index, points in enumerate(sample_points):
Expand All @@ -539,10 +560,6 @@ def _points(self, sample_points, data, data_dims=None):
interp_points.append(points)
interp_shape.append(points.size)

interp_shape.extend(
length for dim, length in enumerate(data.shape) if dim not in di
)

# Convert the interpolation points into a cross-product array
# with shape (n_cross_points, n_dims)
interp_points = np.asarray([pts for pts in product(*interp_points)])
Expand All @@ -554,9 +571,20 @@ def _points(self, sample_points, data, data_dims=None):
# Transpose data in preparation for interpolation.
data = np.transpose(data, interp_order)

# Interpolate and reshape the data ...
result = self._interpolate(data, interp_points)
result = result.reshape(interp_shape)
# Interpolate the data, merging the chunks in the interpolated
# dimensions.
dims_merge_chunks = [dmap[d] for d in di]
result = map_complete_blocks(
trexfeathers marked this conversation as resolved.
Show resolved Hide resolved
data,
self._interpolate,
dims=dims_merge_chunks,
out_sizes=interp_shape,
src_points=self._src_points,
interp_points=interp_points,
interp_shape=interp_shape,
method=self._method,
extrapolation_mode=self._mode,
trexfeathers marked this conversation as resolved.
Show resolved Hide resolved
)

if src_order != dims:
# Restore the interpolated result to the original
Expand Down Expand Up @@ -592,7 +620,7 @@ def __call__(self, sample_points, collapse_scalar=True):

sample_points = _canonical_sample_points(self._src_coords, sample_points)

data = self._src_cube.data
data = self._src_cube.core_data()
# Interpolate the cube payload.
interpolated_data = self._points(sample_points, data)

Expand Down
Loading