Skip to content

Commit

Permalink
Barycentric interpolation
Browse files Browse the repository at this point in the history
  • Loading branch information
jrmaddison committed Oct 28, 2024
1 parent 2d5cce0 commit 79d3f10
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 21 deletions.
90 changes: 83 additions & 7 deletions bt_ocean/chebyshev.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,27 @@
import jax
import jax.numpy as jnp

from enum import Enum, auto
from functools import cached_property, partial

from .fft import dchebt, idchebt
from .precision import default_idtype, default_fdtype

__all__ = \
[
"InterpolationMethod",
"Chebyshev"
]


class InterpolationMethod(Enum):
"""Method to use for interpolation. See :meth:`.Chebyshev.interpolate`.
"""

BARYCENTRIC = auto()
CLENSHAW = auto()


class Chebyshev:
"""Chebyshev pseudospectral utility class.
Expand Down Expand Up @@ -171,15 +181,68 @@ def from_cheb(self, u_c, *, axis=-1):
raise ValueError("Invalid shape")
return idchebt(u_c, axis=axis)

def interpolate(self, u, x, *, axis=-1, extrapolate=False):
@staticmethod
@partial(jax.jit, static_argnums=(0, 1), static_argnames="axis")
def _interpolate(idtype, fdtype, u, x, x_c, *, axis=-1):
if axis == -1:
axis = len(u.shape) - 1
if axis != len(u.shape) - 1:
p = tuple(range(axis)) + tuple(range(axis + 1, len(u.shape))) + (axis,)
p_inv = tuple(range(axis)) + (len(u.shape) - 1,) + tuple(range(axis, len(u.shape) - 1))
u = jnp.transpose(u, p)

N = u.shape[-1] - 1

# Equation (5.4) in
# Jean-Paul Berrut and Lloyd N. Trefethen, 'Barycentric Lagrange
# interpolation', SIAM Review 46(3), pp. 501--517, 2004,
# https://doi.org/10.1137/S0036144502417715
# Note that coordinate order reversal might change the sign here, but
# this does not change the result due to cancelling factors
w = jnp.array((-1) ** jnp.arange(N + 1, dtype=idtype), dtype=fdtype)
w = w.at[0].set(0.5 * w[0])
w = w.at[-1].set(0.5 * w[-1])

# Equation (4.2) in
# Jean-Paul Berrut and Lloyd N. Trefethen, 'Barycentric Lagrange
# interpolation', SIAM Review 46(3), pp. 501--517, 2004,
# https://doi.org/10.1137/S0036144502417715
# When an interpolation point lies exactly on a grid point we
# arbitrarily replace the divide-by-zero with divide-by-one. The
# erroneous result will be discarded below.
X = jnp.array(jnp.tensordot(x, jnp.ones_like(x, shape=x_c.shape), axes=0), dtype=fdtype)
X_exact = jnp.array(abs(X - x_c) < jnp.finfo(fdtype).eps, dtype=idtype)
a = w / jnp.where(X_exact == 1, jnp.ones_like(X), X - x_c)
v = jnp.tensordot(u, a, axes=((-1,), (-1,))) / a.sum(axis=-1)

# Now handle the case where an interpolation point lies exactly on a
# grid point
v = jnp.where(jnp.tensordot(jnp.ones(u.shape, dtype=idtype), X_exact, axes=((-1,), (-1,))) == 1,
jnp.tensordot(u, X_exact, axes=((-1,), (-1,))),
v)

if axis != len(u.shape) - 1:
v = jnp.transpose(v, p_inv)
return v

def interpolate(self, u, x, *, axis=-1, interpolation_method=InterpolationMethod.BARYCENTRIC, extrapolate=False):
"""Evaluate at the given locations, given an array of grid point
values.
Computed by transforming to an expansion in the Chebyshev basis, and
then using the Clenshaw algorithm, using equations (2) and (3) in
- C. W. Clenshaw, 'A note on the summation of Chebyshev series',
Mathematics of Computation 9, 118--120, 1955
The `interpolation_method` argument chooses between two interpolation
methods:
- `InterpolationMethod.BARYCENTRIC`: Barycentric interpolation as
described in Jean-Paul Berrut and Lloyd N. Trefethen,
'Barycentric Lagrange interpolation', SIAM Review 46(3),
pp. 501--517, 2004, https://doi.org/10.1137/S0036144502417715,
in particular using their equations (4.2) and (5.4).
- `InterpolationMethod.CLENSHAW`: Interpolation performed by
first transforming to an expansion in the Chebyshev basis using
:meth:`.to_cheb`, and then using using the Clenshaw algorithm,
using equations (2) and (3) in C. W. Clenshaw, 'A note on the
summation of Chebyshev series', Mathematics of Computation 9,
118--120, 1955.
Parameters
----------
Expand All @@ -190,6 +253,8 @@ def interpolate(self, u, x, *, axis=-1, extrapolate=False):
Array of locations.
axis : Integral
Axis over which to perform the evaluation.
interpolation_method
The interpolation method.
extrapolate : bool
Whether to allow extrapolation.
Expand All @@ -200,7 +265,16 @@ def interpolate(self, u, x, *, axis=-1, extrapolate=False):
Array of values at the given locations.
"""

return self.interpolate_cheb(self.to_cheb(u, axis=axis), x, axis=axis, extrapolate=extrapolate)
if u.shape[axis] != self.N + 1:
raise ValueError("Invalid shape")
if not extrapolate and (abs(x) > 1).any():
raise ValueError("Out of bounds")
if interpolation_method == InterpolationMethod.BARYCENTRIC:
return self._interpolate(self.idtype, self.fdtype, u, x, self.x, axis=axis)
elif interpolation_method == InterpolationMethod.CLENSHAW:
return self.interpolate_cheb(self.to_cheb(u, axis=axis), x, axis=axis, extrapolate=extrapolate)
else:
raise ValueError(f"Unrecognized interpolation method: {interpolation_method}")

@staticmethod
@partial(jax.jit, static_argnames="axis")
Expand Down Expand Up @@ -277,6 +351,8 @@ def _D(idtype, fdtype, x):
# Lloyd N. Trefethen, 'Spectral methods in MATLAB', Society for
# Industrial and Applied Mathematics, 2000,
# https://doi.org/10.1137/1.9780898719598
# Coordinate order reversal does not change the expression for the
# differentiation matrix.
c = jnp.ones_like(x)
c = c.at[0].set(2)
c = c.at[-1].set(2)
Expand Down
10 changes: 6 additions & 4 deletions bt_ocean/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from functools import cached_property
from numbers import Real

from .chebyshev import Chebyshev
from .chebyshev import Chebyshev, InterpolationMethod
from .precision import default_idtype, default_fdtype

__all__ = \
Expand Down Expand Up @@ -146,7 +146,7 @@ def Y(self) -> jax.Array:
return jnp.outer(
jnp.ones(self.N_x + 1, dtype=self.fdtype), self.y)

def interpolate(self, u, x, y, *, extrapolate=False):
def interpolate(self, u, x, y, *, interpolation_method=InterpolationMethod.BARYCENTRIC, extrapolate=False):
"""Evaluate on a grid.
Parameters
Expand All @@ -158,6 +158,8 @@ def interpolate(self, u, x, y, *, extrapolate=False):
:math:`x`-coordinates.
y : :class:`jax.Array`
:math:`y`-coordinates.
interpolation_method
The interpolation method. See :meth:`.Chebyshev.interpolate`.
extrapolate : bool
Whether to allow extrapolation.
Expand All @@ -168,8 +170,8 @@ def interpolate(self, u, x, y, *, extrapolate=False):
Array of values on the grid.
"""

v = self.cheb_x.interpolate(u, x / self.L_x, axis=0, extrapolate=extrapolate)
v = self.cheb_y.interpolate(v, y / self.L_y, axis=1, extrapolate=extrapolate)
v = self.cheb_x.interpolate(u, x / self.L_x, axis=0, interpolation_method=interpolation_method, extrapolate=extrapolate)
v = self.cheb_y.interpolate(v, y / self.L_y, axis=1, interpolation_method=interpolation_method, extrapolate=extrapolate)
return v

@cached_property
Expand Down
28 changes: 25 additions & 3 deletions tests/test_chebyshev.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from bt_ocean.chebyshev import Chebyshev
from bt_ocean.chebyshev import Chebyshev, InterpolationMethod

import jax.numpy as jnp
from numpy import exp
Expand Down Expand Up @@ -33,7 +33,29 @@ def test_chebyshev_basis(alpha, N):


@pytest.mark.parametrize("N", tuple(range(3, 11)) + (128, 129))
def test_chebyshev_interpolation(N):
@pytest.mark.parametrize("interpolation_method", [InterpolationMethod.BARYCENTRIC,
InterpolationMethod.CLENSHAW])
def test_chebyshev_interpolation_identity(N, interpolation_method):
cheb = Chebyshev(N)

def u0(x):
return jnp.sqrt(2) - jnp.sqrt(3) * x + jnp.sqrt(5) * x ** 2 - jnp.sqrt(7) * x ** 3

def u1(x):
return -jnp.sqrt(11) + jnp.sqrt(13) * x - jnp.sqrt(17) * x ** 2 + jnp.sqrt(19) * x ** 3

u = jnp.vstack((u0(cheb.x), u1(cheb.x))).T
x = cheb.x

v = cheb.interpolate(u, x, axis=0, interpolation_method=interpolation_method)
assert abs(v[:, 0] - u0(x)).max() < 100 * eps()
assert abs(v[:, 1] - u1(x)).max() < 100 * eps()


@pytest.mark.parametrize("N", tuple(range(3, 11)) + (128, 129))
@pytest.mark.parametrize("interpolation_method", [InterpolationMethod.BARYCENTRIC,
InterpolationMethod.CLENSHAW])
def test_chebyshev_interpolation(N, interpolation_method):
cheb = Chebyshev(N)

def u0(x):
Expand All @@ -45,7 +67,7 @@ def u1(x):
u = jnp.vstack((u0(cheb.x), u1(cheb.x))).T
x = jnp.array((-1 / jnp.pi, 2 / jnp.pi))

v = cheb.interpolate(u, x, axis=0)
v = cheb.interpolate(u, x, axis=0, interpolation_method=interpolation_method)
assert abs(v[:, 0] - u0(x)).max() < 100 * eps()
assert abs(v[:, 1] - u1(x)).max() < 100 * eps()

Expand Down
20 changes: 13 additions & 7 deletions tests/test_grid.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from bt_ocean.grid import Grid
from bt_ocean.grid import Grid, InterpolationMethod

import jax.numpy as jnp
from numpy import cbrt
Expand All @@ -14,7 +14,9 @@
(10, 20),
(20, 10),
(32, 64)])
def test_interpolate_identity(L_x, L_y, N_x, N_y):
@pytest.mark.parametrize("interpolation_method", [InterpolationMethod.BARYCENTRIC,
InterpolationMethod.CLENSHAW])
def test_interpolate_identity(L_x, L_y, N_x, N_y, interpolation_method):
grid = Grid(L_x, L_y, N_x, N_y)

def u0(X, Y):
Expand All @@ -24,7 +26,7 @@ def u0(X, Y):
u = u0(grid.X, grid.Y)
x = grid.x
y = grid.y
v = grid.interpolate(u, x, y)
v = grid.interpolate(u, x, y, interpolation_method=interpolation_method)
assert abs(v - u0(jnp.outer(x, jnp.ones_like(y)), jnp.outer(jnp.ones_like(x), y))).max() < 100 * eps()


Expand All @@ -34,7 +36,9 @@ def u0(X, Y):
(10, 20),
(20, 10),
(32, 64)])
def test_interpolate_uniform(L_x, L_y, N_x, N_y):
@pytest.mark.parametrize("interpolation_method", [InterpolationMethod.BARYCENTRIC,
InterpolationMethod.CLENSHAW])
def test_interpolate_uniform(L_x, L_y, N_x, N_y, interpolation_method):
grid = Grid(L_x, L_y, N_x, N_y)

def u0(X, Y):
Expand All @@ -44,7 +48,7 @@ def u0(X, Y):
u = u0(grid.X, grid.Y)
x = jnp.linspace(-grid.L_x, grid.L_x, 17)
y = jnp.linspace(-grid.L_y, grid.L_y, 19)
v = grid.interpolate(u, x, y)
v = grid.interpolate(u, x, y, interpolation_method=interpolation_method)
assert abs(v - u0(jnp.outer(x, jnp.ones_like(y)), jnp.outer(jnp.ones_like(x), y))).max() < 100 * eps()


Expand All @@ -54,7 +58,9 @@ def u0(X, Y):
(10, 20),
(20, 10),
(32, 64)])
def test_interpolate_non_uniform(L_x, L_y, N_x, N_y):
@pytest.mark.parametrize("interpolation_method", [InterpolationMethod.BARYCENTRIC,
InterpolationMethod.CLENSHAW])
def test_interpolate_non_uniform(L_x, L_y, N_x, N_y, interpolation_method):
grid = Grid(L_x, L_y, N_x, N_y)

def u0(X, Y):
Expand All @@ -64,5 +70,5 @@ def u0(X, Y):
u = u0(grid.X, grid.Y)
x = jnp.logspace(-1, 0, 17) * grid.L_x
y = jnp.logspace(-2, 0, 19) * grid.L_y
v = grid.interpolate(u, x, y)
v = grid.interpolate(u, x, y, interpolation_method=interpolation_method)
assert abs(v - u0(jnp.outer(x, jnp.ones_like(y)), jnp.outer(jnp.ones_like(x), y))).max() < 100 * eps()

0 comments on commit 79d3f10

Please sign in to comment.