From 79d3f10aa857d40435b7033b6c4592e6e5bbe4c2 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Mon, 28 Oct 2024 15:23:38 +0000 Subject: [PATCH] Barycentric interpolation --- bt_ocean/chebyshev.py | 90 +++++++++++++++++++++++++++++++++++++---- bt_ocean/grid.py | 10 +++-- tests/test_chebyshev.py | 28 +++++++++++-- tests/test_grid.py | 20 +++++---- 4 files changed, 127 insertions(+), 21 deletions(-) diff --git a/bt_ocean/chebyshev.py b/bt_ocean/chebyshev.py index c30f23a..e8dc8ca 100644 --- a/bt_ocean/chebyshev.py +++ b/bt_ocean/chebyshev.py @@ -4,6 +4,7 @@ import jax import jax.numpy as jnp +from enum import Enum, auto from functools import cached_property, partial from .fft import dchebt, idchebt @@ -11,10 +12,19 @@ __all__ = \ [ + "InterpolationMethod", "Chebyshev" ] +class InterpolationMethod(Enum): + """Method to use for interpolation. See :meth:`.Chebyshev.interpolate`. + """ + + BARYCENTRIC = auto() + CLENSHAW = auto() + + class Chebyshev: """Chebyshev pseudospectral utility class. @@ -171,15 +181,68 @@ def from_cheb(self, u_c, *, axis=-1): raise ValueError("Invalid shape") return idchebt(u_c, axis=axis) - def interpolate(self, u, x, *, axis=-1, extrapolate=False): + @staticmethod + @partial(jax.jit, static_argnums=(0, 1), static_argnames="axis") + def _interpolate(idtype, fdtype, u, x, x_c, *, axis=-1): + if axis == -1: + axis = len(u.shape) - 1 + if axis != len(u.shape) - 1: + p = tuple(range(axis)) + tuple(range(axis + 1, len(u.shape))) + (axis,) + p_inv = tuple(range(axis)) + (len(u.shape) - 1,) + tuple(range(axis, len(u.shape) - 1)) + u = jnp.transpose(u, p) + + N = u.shape[-1] - 1 + + # Equation (5.4) in + # Jean-Paul Berrut and Lloyd N. Trefethen, 'Barycentric Lagrange + # interpolation', SIAM Review 46(3), pp. 501--517, 2004, + # https://doi.org/10.1137/S0036144502417715 + # Note that coordinate order reversal might change the sign here, but + # this does not change the result due to cancelling factors + w = jnp.array((-1) ** jnp.arange(N + 1, dtype=idtype), dtype=fdtype) + w = w.at[0].set(0.5 * w[0]) + w = w.at[-1].set(0.5 * w[-1]) + + # Equation (4.2) in + # Jean-Paul Berrut and Lloyd N. Trefethen, 'Barycentric Lagrange + # interpolation', SIAM Review 46(3), pp. 501--517, 2004, + # https://doi.org/10.1137/S0036144502417715 + # When an interpolation point lies exactly on a grid point we + # arbitrarily replace the divide-by-zero with divide-by-one. The + # erroneous result will be discarded below. + X = jnp.array(jnp.tensordot(x, jnp.ones_like(x, shape=x_c.shape), axes=0), dtype=fdtype) + X_exact = jnp.array(abs(X - x_c) < jnp.finfo(fdtype).eps, dtype=idtype) + a = w / jnp.where(X_exact == 1, jnp.ones_like(X), X - x_c) + v = jnp.tensordot(u, a, axes=((-1,), (-1,))) / a.sum(axis=-1) + + # Now handle the case where an interpolation point lies exactly on a + # grid point + v = jnp.where(jnp.tensordot(jnp.ones(u.shape, dtype=idtype), X_exact, axes=((-1,), (-1,))) == 1, + jnp.tensordot(u, X_exact, axes=((-1,), (-1,))), + v) + + if axis != len(u.shape) - 1: + v = jnp.transpose(v, p_inv) + return v + + def interpolate(self, u, x, *, axis=-1, interpolation_method=InterpolationMethod.BARYCENTRIC, extrapolate=False): """Evaluate at the given locations, given an array of grid point values. - Computed by transforming to an expansion in the Chebyshev basis, and - then using the Clenshaw algorithm, using equations (2) and (3) in - - - C. W. Clenshaw, 'A note on the summation of Chebyshev series', - Mathematics of Computation 9, 118--120, 1955 + The `interpolation_method` argument chooses between two interpolation + methods: + + - `InterpolationMethod.BARYCENTRIC`: Barycentric interpolation as + described in Jean-Paul Berrut and Lloyd N. Trefethen, + 'Barycentric Lagrange interpolation', SIAM Review 46(3), + pp. 501--517, 2004, https://doi.org/10.1137/S0036144502417715, + in particular using their equations (4.2) and (5.4). + - `InterpolationMethod.CLENSHAW`: Interpolation performed by + first transforming to an expansion in the Chebyshev basis using + :meth:`.to_cheb`, and then using using the Clenshaw algorithm, + using equations (2) and (3) in C. W. Clenshaw, 'A note on the + summation of Chebyshev series', Mathematics of Computation 9, + 118--120, 1955. Parameters ---------- @@ -190,6 +253,8 @@ def interpolate(self, u, x, *, axis=-1, extrapolate=False): Array of locations. axis : Integral Axis over which to perform the evaluation. + interpolation_method + The interpolation method. extrapolate : bool Whether to allow extrapolation. @@ -200,7 +265,16 @@ def interpolate(self, u, x, *, axis=-1, extrapolate=False): Array of values at the given locations. """ - return self.interpolate_cheb(self.to_cheb(u, axis=axis), x, axis=axis, extrapolate=extrapolate) + if u.shape[axis] != self.N + 1: + raise ValueError("Invalid shape") + if not extrapolate and (abs(x) > 1).any(): + raise ValueError("Out of bounds") + if interpolation_method == InterpolationMethod.BARYCENTRIC: + return self._interpolate(self.idtype, self.fdtype, u, x, self.x, axis=axis) + elif interpolation_method == InterpolationMethod.CLENSHAW: + return self.interpolate_cheb(self.to_cheb(u, axis=axis), x, axis=axis, extrapolate=extrapolate) + else: + raise ValueError(f"Unrecognized interpolation method: {interpolation_method}") @staticmethod @partial(jax.jit, static_argnames="axis") @@ -277,6 +351,8 @@ def _D(idtype, fdtype, x): # Lloyd N. Trefethen, 'Spectral methods in MATLAB', Society for # Industrial and Applied Mathematics, 2000, # https://doi.org/10.1137/1.9780898719598 + # Coordinate order reversal does not change the expression for the + # differentiation matrix. c = jnp.ones_like(x) c = c.at[0].set(2) c = c.at[-1].set(2) diff --git a/bt_ocean/grid.py b/bt_ocean/grid.py index bb609cf..84f9ff3 100644 --- a/bt_ocean/grid.py +++ b/bt_ocean/grid.py @@ -7,7 +7,7 @@ from functools import cached_property from numbers import Real -from .chebyshev import Chebyshev +from .chebyshev import Chebyshev, InterpolationMethod from .precision import default_idtype, default_fdtype __all__ = \ @@ -146,7 +146,7 @@ def Y(self) -> jax.Array: return jnp.outer( jnp.ones(self.N_x + 1, dtype=self.fdtype), self.y) - def interpolate(self, u, x, y, *, extrapolate=False): + def interpolate(self, u, x, y, *, interpolation_method=InterpolationMethod.BARYCENTRIC, extrapolate=False): """Evaluate on a grid. Parameters @@ -158,6 +158,8 @@ def interpolate(self, u, x, y, *, extrapolate=False): :math:`x`-coordinates. y : :class:`jax.Array` :math:`y`-coordinates. + interpolation_method + The interpolation method. See :meth:`.Chebyshev.interpolate`. extrapolate : bool Whether to allow extrapolation. @@ -168,8 +170,8 @@ def interpolate(self, u, x, y, *, extrapolate=False): Array of values on the grid. """ - v = self.cheb_x.interpolate(u, x / self.L_x, axis=0, extrapolate=extrapolate) - v = self.cheb_y.interpolate(v, y / self.L_y, axis=1, extrapolate=extrapolate) + v = self.cheb_x.interpolate(u, x / self.L_x, axis=0, interpolation_method=interpolation_method, extrapolate=extrapolate) + v = self.cheb_y.interpolate(v, y / self.L_y, axis=1, interpolation_method=interpolation_method, extrapolate=extrapolate) return v @cached_property diff --git a/tests/test_chebyshev.py b/tests/test_chebyshev.py index dbcd54d..fd0fbfa 100644 --- a/tests/test_chebyshev.py +++ b/tests/test_chebyshev.py @@ -1,4 +1,4 @@ -from bt_ocean.chebyshev import Chebyshev +from bt_ocean.chebyshev import Chebyshev, InterpolationMethod import jax.numpy as jnp from numpy import exp @@ -33,7 +33,29 @@ def test_chebyshev_basis(alpha, N): @pytest.mark.parametrize("N", tuple(range(3, 11)) + (128, 129)) -def test_chebyshev_interpolation(N): +@pytest.mark.parametrize("interpolation_method", [InterpolationMethod.BARYCENTRIC, + InterpolationMethod.CLENSHAW]) +def test_chebyshev_interpolation_identity(N, interpolation_method): + cheb = Chebyshev(N) + + def u0(x): + return jnp.sqrt(2) - jnp.sqrt(3) * x + jnp.sqrt(5) * x ** 2 - jnp.sqrt(7) * x ** 3 + + def u1(x): + return -jnp.sqrt(11) + jnp.sqrt(13) * x - jnp.sqrt(17) * x ** 2 + jnp.sqrt(19) * x ** 3 + + u = jnp.vstack((u0(cheb.x), u1(cheb.x))).T + x = cheb.x + + v = cheb.interpolate(u, x, axis=0, interpolation_method=interpolation_method) + assert abs(v[:, 0] - u0(x)).max() < 100 * eps() + assert abs(v[:, 1] - u1(x)).max() < 100 * eps() + + +@pytest.mark.parametrize("N", tuple(range(3, 11)) + (128, 129)) +@pytest.mark.parametrize("interpolation_method", [InterpolationMethod.BARYCENTRIC, + InterpolationMethod.CLENSHAW]) +def test_chebyshev_interpolation(N, interpolation_method): cheb = Chebyshev(N) def u0(x): @@ -45,7 +67,7 @@ def u1(x): u = jnp.vstack((u0(cheb.x), u1(cheb.x))).T x = jnp.array((-1 / jnp.pi, 2 / jnp.pi)) - v = cheb.interpolate(u, x, axis=0) + v = cheb.interpolate(u, x, axis=0, interpolation_method=interpolation_method) assert abs(v[:, 0] - u0(x)).max() < 100 * eps() assert abs(v[:, 1] - u1(x)).max() < 100 * eps() diff --git a/tests/test_grid.py b/tests/test_grid.py index 24765f3..10fb867 100644 --- a/tests/test_grid.py +++ b/tests/test_grid.py @@ -1,4 +1,4 @@ -from bt_ocean.grid import Grid +from bt_ocean.grid import Grid, InterpolationMethod import jax.numpy as jnp from numpy import cbrt @@ -14,7 +14,9 @@ (10, 20), (20, 10), (32, 64)]) -def test_interpolate_identity(L_x, L_y, N_x, N_y): +@pytest.mark.parametrize("interpolation_method", [InterpolationMethod.BARYCENTRIC, + InterpolationMethod.CLENSHAW]) +def test_interpolate_identity(L_x, L_y, N_x, N_y, interpolation_method): grid = Grid(L_x, L_y, N_x, N_y) def u0(X, Y): @@ -24,7 +26,7 @@ def u0(X, Y): u = u0(grid.X, grid.Y) x = grid.x y = grid.y - v = grid.interpolate(u, x, y) + v = grid.interpolate(u, x, y, interpolation_method=interpolation_method) assert abs(v - u0(jnp.outer(x, jnp.ones_like(y)), jnp.outer(jnp.ones_like(x), y))).max() < 100 * eps() @@ -34,7 +36,9 @@ def u0(X, Y): (10, 20), (20, 10), (32, 64)]) -def test_interpolate_uniform(L_x, L_y, N_x, N_y): +@pytest.mark.parametrize("interpolation_method", [InterpolationMethod.BARYCENTRIC, + InterpolationMethod.CLENSHAW]) +def test_interpolate_uniform(L_x, L_y, N_x, N_y, interpolation_method): grid = Grid(L_x, L_y, N_x, N_y) def u0(X, Y): @@ -44,7 +48,7 @@ def u0(X, Y): u = u0(grid.X, grid.Y) x = jnp.linspace(-grid.L_x, grid.L_x, 17) y = jnp.linspace(-grid.L_y, grid.L_y, 19) - v = grid.interpolate(u, x, y) + v = grid.interpolate(u, x, y, interpolation_method=interpolation_method) assert abs(v - u0(jnp.outer(x, jnp.ones_like(y)), jnp.outer(jnp.ones_like(x), y))).max() < 100 * eps() @@ -54,7 +58,9 @@ def u0(X, Y): (10, 20), (20, 10), (32, 64)]) -def test_interpolate_non_uniform(L_x, L_y, N_x, N_y): +@pytest.mark.parametrize("interpolation_method", [InterpolationMethod.BARYCENTRIC, + InterpolationMethod.CLENSHAW]) +def test_interpolate_non_uniform(L_x, L_y, N_x, N_y, interpolation_method): grid = Grid(L_x, L_y, N_x, N_y) def u0(X, Y): @@ -64,5 +70,5 @@ def u0(X, Y): u = u0(grid.X, grid.Y) x = jnp.logspace(-1, 0, 17) * grid.L_x y = jnp.logspace(-2, 0, 19) * grid.L_y - v = grid.interpolate(u, x, y) + v = grid.interpolate(u, x, y, interpolation_method=interpolation_method) assert abs(v - u0(jnp.outer(x, jnp.ones_like(y)), jnp.outer(jnp.ones_like(x), y))).max() < 100 * eps()