Skip to content

Commit

Permalink
compiler: Fix unexpansion w custom coeffs
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioLuporini committed Oct 20, 2023
1 parent dd1b154 commit d62b080
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 1 deletion.
7 changes: 6 additions & 1 deletion devito/finite_differences/coefficients.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
from cached_property import cached_property

from devito.finite_differences import generate_indices
from devito.finite_differences import Weights, generate_indices
from devito.finite_differences.tools import numeric_weights, symbolic_weights
from devito.tools import filter_ordered, as_tuple

Expand Down Expand Up @@ -268,8 +268,13 @@ def generate_subs(deriv_order, function, index):
return subs

# Determine which 'rules' are missing

sym = get_sym(functions)
terms = obj.find(sym)
for i in obj.find(Weights):
for w in i.weights:
terms.update(w.find(sym))

args_present = filter_ordered(term.args[1:] for term in terms)

subs = obj.substitutions
Expand Down
15 changes: 15 additions & 0 deletions devito/finite_differences/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,21 @@ def spacings(self):

weights = Array.initvalue

def _xreplace(self, rule):
if self in rule:
return rule[self], True
elif not rule:
return self, False
else:
try:
weights, flags = zip(*[i._xreplace(rule) for i in self.weights])
if any(flags):
return self.func(initvalue=weights, function=None), True
except AttributeError:
# `float` weights
pass
return super()._xreplace(rule)


class IndexDerivative(IndexSum):

Expand Down
17 changes: 17 additions & 0 deletions tests/test_unexpansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,23 @@ def test_backward_dt2(self):
assert_structure(op, ['t,x,y'], 't,x,y')


class TestSymbolicCoefficients(object):

def test_fallback_to_default(self):
grid = Grid(shape=(8, 8, 8))

u = TimeFunction(name='u', grid=grid, coefficients='symbolic',
space_order=4, time_order=2)

eq = Eq(u.forward, u.dx2 + 1)

op = Operator(eq, opt=('advanced', {'expand': False}))

# Ensure all symbols have been resolved
op.arguments(dt=1, time_M=10)
op.cfunction


class Test1Pass(object):

def test_v0(self):
Expand Down

0 comments on commit d62b080

Please sign in to comment.