Skip to content

Commit

Permalink
api: enforce interpolation radius to be smaller than any input space …
Browse files Browse the repository at this point in the history
…order
  • Loading branch information
mloubout committed Oct 12, 2023
1 parent 074df11 commit c8fd8bd
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 18 deletions.
17 changes: 16 additions & 1 deletion devito/operations/interpolators.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from abc import ABC, abstractmethod
from functools import wraps

import sympy
from cached_property import cached_property

from devito.finite_differences.differentiable import Mul
from devito.finite_differences.elementary import floor
from devito.symbolics import retrieve_function_carriers, INT
from devito.symbolics import retrieve_function_carriers, retrieve_functions, INT
from devito.tools import as_tuple, flatten
from devito.types import (ConditionalDimension, Eq, Inc, Evaluable, Symbol,
CustomDimension)
Expand All @@ -14,6 +15,18 @@
__all__ = ['LinearInterpolator', 'PrecomputedInterpolator']


def check_radius(func):
@wraps(func)
def wrapper(interp, *args, **kwargs):
r = interp.sfunction.r
funcs = set().union(*[retrieve_functions(a) for a in args])
so = min({f.space_order for f in funcs if not f.is_SparseFunction} or {r})
if so < r:
raise ValueError("Space order %d smaller than interpolation r %d" % (so, r))
return func(interp, *args, **kwargs)
return wrapper


class UnevaluatedSparseOperation(sympy.Expr, Evaluable):

"""
Expand Down Expand Up @@ -209,6 +222,7 @@ def _interp_idx(self, variables, implicit_dims=None):

return idx_subs, temps

@check_radius
def interpolate(self, expr, increment=False, self_subs={}, implicit_dims=None):
"""
Generate equations interpolating an arbitrary expression into ``self``.
Expand All @@ -226,6 +240,7 @@ def interpolate(self, expr, increment=False, self_subs={}, implicit_dims=None):
"""
return Interpolation(expr, increment, implicit_dims, self_subs, self)

@check_radius
def inject(self, field, expr, implicit_dims=None):
"""
Generate equations injecting an arbitrary expression into a field.
Expand Down
2 changes: 1 addition & 1 deletion tests/test_dle.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,7 @@ def test_scheduling(self):
"""
grid = Grid(shape=(11, 11))

u = TimeFunction(name='u', grid=grid, time_order=2, save=5, space_order=0)
u = TimeFunction(name='u', grid=grid, time_order=2, save=5, space_order=1)
sf1 = SparseTimeFunction(name='s', grid=grid, npoint=1, nt=5)

eqns = [Eq(u.forward, u + 1)]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_gpu_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1424,7 +1424,7 @@ def test_empty_arrays(self):
"""
grid = Grid(shape=(4, 4), extent=(3.0, 3.0))

f = TimeFunction(name='f', grid=grid, space_order=0)
f = TimeFunction(name='f', grid=grid, space_order=1)
f.data[:] = 1.
sf1 = SparseTimeFunction(name='sf1', grid=grid, npoint=0, nt=10)
sf2 = SparseTimeFunction(name='sf2', grid=grid, npoint=0, nt=10,
Expand Down
24 changes: 17 additions & 7 deletions tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,16 +117,15 @@ def test_precomputed_interpolation(r):
origin = (0, 0)

grid = Grid(shape=shape, origin=origin)
r = 2 # Constant for linear interpolation
# because we interpolate across 2 neighbouring points in each dimension

def init(data):
# This is data with halo so need to shift to match the m.data expectations
for i in range(data.shape[0]):
for j in range(data.shape[1]):
data[i, j] = sin(grid.spacing[0]*i) + sin(grid.spacing[1]*j)
data[i, j] = sin(grid.spacing[0]*(i-r)) + sin(grid.spacing[1]*(j-r))
return data

m = Function(name='m', grid=grid, initializer=init, space_order=0)
m = Function(name='m', grid=grid, initializer=init, space_order=r)

gridpoints, interpolation_coeffs = precompute_linear_interpolation(points,
grid, origin,
Expand Down Expand Up @@ -154,10 +153,8 @@ def test_precomputed_interpolation_time(r):
origin = (0, 0)

grid = Grid(shape=shape, origin=origin)
r = 2 # Constant for linear interpolation
# because we interpolate across 2 neighbouring points in each dimension

u = TimeFunction(name='u', grid=grid, space_order=0, save=5)
u = TimeFunction(name='u', grid=grid, space_order=r, save=5)
for it in range(5):
u.data[it, :] = it

Expand Down Expand Up @@ -761,3 +758,16 @@ def test_inject_function():
for i in [0, 1, 3, 4]:
for j in [0, 1, 3, 4]:
assert u.data[1, i, j] == 0


def test_interpolation_radius():
nt = 11

grid = Grid(shape=(5, 5))
u = TimeFunction(name="u", grid=grid, space_order=0)
src = SparseTimeFunction(name="src", grid=grid, nt=nt, npoint=1)
try:
src.interpolate(u)
assert False
except ValueError:
assert True
6 changes: 3 additions & 3 deletions tests/test_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -1501,7 +1501,7 @@ def test_injection_wodup(self):
"""
grid = Grid(shape=(4, 4), extent=(3.0, 3.0))

f = Function(name='f', grid=grid, space_order=0)
f = Function(name='f', grid=grid, space_order=1)
f.data[:] = 0.
coords = np.array([(0.5, 0.5), (0.5, 2.5), (2.5, 0.5), (2.5, 2.5)])
sf = SparseFunction(name='sf', grid=grid, npoint=len(coords), coordinates=coords)
Expand Down Expand Up @@ -1536,7 +1536,7 @@ def test_injection_wodup_wtime(self):
grid = Grid(shape=(4, 4), extent=(3.0, 3.0))

save = 3
f = TimeFunction(name='f', grid=grid, save=save, space_order=0)
f = TimeFunction(name='f', grid=grid, save=save, space_order=1)
f.data[:] = 0.
coords = np.array([(0.5, 0.5), (0.5, 2.5), (2.5, 0.5), (2.5, 2.5)])
sf = SparseTimeFunction(name='sf', grid=grid, nt=save,
Expand Down Expand Up @@ -1611,7 +1611,7 @@ def test_injection_dup(self):
def test_interpolation_wodup(self):
grid = Grid(shape=(4, 4), extent=(3.0, 3.0))

f = Function(name='f', grid=grid, space_order=0)
f = Function(name='f', grid=grid, space_order=1)
f.data[:] = 4.
coords = [(0.5, 0.5), (0.5, 2.5), (2.5, 0.5), (2.5, 2.5)]
sf = SparseFunction(name='sf', grid=grid, npoint=len(coords), coordinates=coords)
Expand Down
10 changes: 5 additions & 5 deletions tests/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ def test_sparsefunction_inject(self):
Test injection of a SparseFunction into a Function
"""
grid = Grid(shape=(11, 11))
u = Function(name='u', grid=grid, space_order=0)
u = Function(name='u', grid=grid, space_order=1)

sf1 = SparseFunction(name='s', grid=grid, npoint=1)
op = Operator(sf1.inject(u, expr=sf1))
Expand All @@ -542,7 +542,7 @@ def test_sparsefunction_interp(self):
Test interpolation of a SparseFunction from a Function
"""
grid = Grid(shape=(11, 11))
u = Function(name='u', grid=grid, space_order=0)
u = Function(name='u', grid=grid, space_order=1)

sf1 = SparseFunction(name='s', grid=grid, npoint=1)
op = Operator(sf1.interpolate(u))
Expand All @@ -563,7 +563,7 @@ def test_sparsetimefunction_interp(self):
Test injection of a SparseTimeFunction into a TimeFunction
"""
grid = Grid(shape=(11, 11))
u = TimeFunction(name='u', grid=grid, time_order=2, save=5, space_order=0)
u = TimeFunction(name='u', grid=grid, time_order=2, save=5, space_order=1)

sf1 = SparseTimeFunction(name='s', grid=grid, npoint=1, nt=5)
op = Operator(sf1.interpolate(u))
Expand All @@ -586,7 +586,7 @@ def test_sparsetimefunction_inject(self):
Test injection of a SparseTimeFunction from a TimeFunction
"""
grid = Grid(shape=(11, 11))
u = TimeFunction(name='u', grid=grid, time_order=2, save=5, space_order=0)
u = TimeFunction(name='u', grid=grid, time_order=2, save=5, space_order=1)

sf1 = SparseTimeFunction(name='s', grid=grid, npoint=1, nt=5)
op = Operator(sf1.inject(u, expr=3*sf1))
Expand All @@ -611,7 +611,7 @@ def test_sparsetimefunction_inject_dt(self):
Test injection of the time deivative of a SparseTimeFunction into a TimeFunction
"""
grid = Grid(shape=(11, 11))
u = TimeFunction(name='u', grid=grid, time_order=2, save=5, space_order=0)
u = TimeFunction(name='u', grid=grid, time_order=2, save=5, space_order=1)

sf1 = SparseTimeFunction(name='s', grid=grid, npoint=1, nt=5, time_order=2)

Expand Down

0 comments on commit c8fd8bd

Please sign in to comment.