Skip to content

Commit

Permalink
Merge pull request #2369 from devitocodes/expand-time
Browse files Browse the repository at this point in the history
api: Always expand time derivatives
  • Loading branch information
mloubout authored May 7, 2024
2 parents 0e0e0ac + f8aa384 commit 84d981e
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
5 changes: 5 additions & 0 deletions devito/finite_differences/finite_difference.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,11 @@ def generic_derivative(expr, dim, fd_order, deriv_order, matvec=direct, x0=None,

def make_derivative(expr, dim, fd_order, deriv_order, side, matvec, x0, coefficients,
expand):
# Always expand time derivatives to avoid issue with buffering and streaming.
# Time derivative are almost always short stencils and won't benefit from
# unexpansion in the rare case the derivative is not evaluated for time stepping.
expand = dim.is_Time or expand

# The stencil indices
indices, x0 = generate_indices(expr, dim, fd_order, side=side, matvec=matvec,
x0=x0)
Expand Down
25 changes: 24 additions & 1 deletion tests/test_unexpansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from devito.arch.compiler import OneapiCompiler
from devito.ir import Expression, FindNodes, FindSymbols
from devito.parameters import switchconfig, configuration
from devito.types import Symbol
from devito.types import Symbol, Dimension


class TestLoopScheduling(object):
Expand Down Expand Up @@ -330,6 +330,29 @@ def test_redundant_derivatives(self):

op.cfunction

def test_buffering_timestencil(self):
grid = Grid((11, 11))
so = 4
nt = 11

u = TimeFunction(name="u", grid=grid, space_order=so, time_order=2, save=nt)
v = TimeFunction(name="v", grid=grid, space_order=so, time_order=2)

g = Function(name="g", grid=grid, space_order=so)

# Make sure operator builds with buffering
op = Operator([Eq(g, g + u.dt*v.dx + u.dx2)],
opt=('buffering', 'streaming', {'expand': False}))

exprs = FindNodes(Expression).visit(op)
dims = [d for i in FindSymbols().visit(exprs) for d in i.dimensions
if isinstance(d, Dimension)]

# Should only be two stencil dimension for .dx and .dx2
assert len([d for d in dims if d.is_Stencil]) == 2
# Should only be one buffer dimension
assert len([d for d in dims if d.is_Custom]) == 1


class Test2Pass(object):

Expand Down

0 comments on commit 84d981e

Please sign in to comment.