diff --git a/devito/finite_differences/derivative.py b/devito/finite_differences/derivative.py index 61e184fbf1..365d90bc99 100644 --- a/devito/finite_differences/derivative.py +++ b/devito/finite_differences/derivative.py @@ -8,13 +8,14 @@ from .differentiable import Differentiable from .tools import direct, transpose from .rsfd import d45 -from devito.tools import as_mapper, as_tuple, filter_ordered, frozendict, is_integer +from devito.tools import (as_mapper, as_tuple, filter_ordered, frozendict, is_integer, + Reconstructable) from devito.types.utils import DimensionTuple __all__ = ['Derivative'] -class Derivative(sympy.Derivative, Differentiable): +class Derivative(sympy.Derivative, Differentiable, Reconstructable): """ An unevaluated Derivative, which carries metadata (Dimensions, @@ -86,7 +87,7 @@ class Derivative(sympy.Derivative, Differentiable): _fd_priority = 3 - __rargs__ = ('expr', 'dims') + __rargs__ = ('expr', '*dims') __rkwargs__ = ('side', 'deriv_order', 'fd_order', 'transpose', '_ppsubs', 'x0', 'method') @@ -201,7 +202,7 @@ def _process_x0(cls, dims, **kwargs): # Only given a value _x0 = kwargs.get('x0') assert len(dims) == 1 or _x0 is None - if _x0 is not None: + if _x0 is not None and _x0 is not dims[0]: x0 = frozendict({dims[0]: _x0}) else: x0 = frozendict({}) @@ -215,8 +216,7 @@ def __call__(self, x0=None, fd_order=None, side=None, method=None): fd_order = fd_order or self._fd_order side = side or self._side method = method or self._method - return self._new_from_self(fd_order=fd_order, side=side, x0=_x0, - method=method) + return self._rebuild(fd_order=fd_order, side=side, x0=_x0, method=method) if side is not None: raise TypeError("Side only supported for first order single" @@ -230,18 +230,13 @@ def __call__(self, x0=None, fd_order=None, side=None, method=None): except AttributeError: raise TypeError("Multi-dimensional Derivative, input expected as a dict") - return self._new_from_self(fd_order=_fd_order, x0=_x0) + return self._rebuild(fd_order=_fd_order, x0=_x0) - def _new_from_self(self, **kwargs): - expr = kwargs.pop('expr', self.expr) - _kwargs = {'deriv_order': self.deriv_order, 'fd_order': self.fd_order, - 'side': self.side, 'transpose': self.transpose, 'subs': self._ppsubs, - 'x0': self.x0, 'preprocessed': True, 'method': self.method} - _kwargs.update(**kwargs) - return Derivative(expr, *self.dims, **_kwargs) + def _rebuild(self, *args, **kwargs): + kwargs['preprocessed'] = True + return super()._rebuild(*args, **kwargs) - def func(self, expr, *args, **kwargs): - return self._new_from_self(expr=expr, **kwargs) + func = _rebuild def _subs(self, old, new, **hints): # Basic case @@ -251,7 +246,7 @@ def _subs(self, old, new, **hints): if self.expr.has(old): newexpr = self.expr._subs(old, new, **hints) try: - return self._new_from_self(expr=newexpr) + return self._rebuild(expr=newexpr) except ValueError: # Expr replacement leads to non-differentiable expression # e.g `f.dx.subs(f: 1) = 1.dx = 0` @@ -260,7 +255,7 @@ def _subs(self, old, new, **hints): # In case `x0` was passed as a substitution instead of `(x0=` if str(old) == 'x0': - return self._new_from_self(x0={self.dims[0]: new}) + return self._rebuild(x0={self.dims[0]: new}) # Trying to substitute by another derivative with different metadata # Only need to check if is a Derivative since one for the cases above would @@ -289,13 +284,11 @@ def _xreplace(self, subs): return new, True subs = self._ppsubs + (subs,) # Postponed substitutions - return self._new_from_self(subs=subs), True + return self._rebuild(subs=subs), True @cached_property def _metadata(self): - state = list(self.__rargs__ + self.__rkwargs__) - state.remove('expr') - ret = [getattr(self, i) for i in state] + ret = [self.dims] + [getattr(self, i) for i in self.__rkwargs__] ret.append(self.expr.staggered or (None,)) return tuple(ret) @@ -348,7 +341,7 @@ def T(self): else: adjoint = direct - return self._new_from_self(transpose=adjoint) + return self._rebuild(transpose=adjoint) def _eval_at(self, func): """ @@ -360,6 +353,10 @@ def _eval_at(self, func): # do not overwrite it if self.x0 or self.side is not None or func.function is self.expr.function: return self + # For basic equation of the form f = Derivative(g, ...) we can just + # compare staggering + if self.expr.staggered == func.staggered: + return self x0 = func.indices_ref._getters if self.expr.is_Add: @@ -370,19 +367,19 @@ def _eval_at(self, func): mapper = as_mapper(self.expr._args_diff, lambda i: i.staggered) args = [self.expr.func(*v) for v in mapper.values()] args.extend([a for a in self.expr.args if a not in self.expr._args_diff]) - args = [self._new_from_self(expr=a, x0=x0) for a in args] + args = [self._rebuild(expr=a, x0=x0) for a in args] return self.expr.func(*args) elif self.expr.is_Mul: # For Mul, We treat the basic case `u(x + h_x/2) * v(x) which is what appear # in most equation with div(a * u) for example. The expression is re-centered # at the highest priority index (see _gather_for_diff) to compute the # derivative at x0. - return self._new_from_self(x0=x0, expr=self.expr._gather_for_diff) + return self._rebuild(expr=self.expr._gather_for_diff, x0=x0) else: # For every other cases, that has more functions or more complexe arithmetic, # there is not actual way to decide what to do so it’s as safe to use # the expression as is. - return self._new_from_self(x0=x0) + return self._rebuild(x0=x0) def _evaluate(self, **kwargs): # Evaluate finite-difference. diff --git a/devito/passes/equations/linearity.py b/devito/passes/equations/linearity.py index 66390f821e..c07068996b 100644 --- a/devito/passes/equations/linearity.py +++ b/devito/passes/equations/linearity.py @@ -103,7 +103,7 @@ def _(expr, mapper, nn_derivs=None): @aggregate_coeffs.register(sympy.Derivative) def _(expr, mapper, nn_derivs=None): # Opens up a new derivative scope, so do not propagate `nn_derivs` - args = [aggregate_coeffs(a, mapper) for a in expr.args] + args = [aggregate_coeffs(expr.expr, mapper)] expr = reuse_if_untouched(expr, args) return expr @@ -164,10 +164,10 @@ def _(expr, mapper, nn_derivs=None): return expr if len(derivs) == 1 and with_deriv is derivs[0]: - expr = with_deriv._new_from_self(expr=expr.func(*hope_coeffs, with_deriv.expr)) + expr = with_deriv._rebuild(expr=expr.func(*hope_coeffs, with_deriv.expr)) else: others = [expr.func(*hope_coeffs, a) for a in others] - derivs = [a._new_from_self(expr=expr.func(*hope_coeffs, a.expr)) for a in derivs] + derivs = [a._rebuild(expr=expr.func(*hope_coeffs, a.expr)) for a in derivs] expr = with_deriv.func(*(derivs + others)) return expr @@ -190,6 +190,14 @@ def _(expr): return expr +@factorize_derivatives.register(sympy.Derivative) +def _(expr): + args = [factorize_derivatives(expr.expr)] + expr = reuse_if_untouched(expr, args) + + return expr + + @factorize_derivatives.register(sympy.Add) def _(expr): args = [factorize_derivatives(a) for a in expr.args] @@ -216,7 +224,7 @@ def _(expr): if len(v) == 1: args.append(c) else: - args.append(c._new_from_self(expr=expr.func(*[i.expr for i in v]))) + args.append(c._rebuild(expr=expr.func(*[i.expr for i in v]))) expr = expr.func(*args) return expr diff --git a/devito/tools/abc.py b/devito/tools/abc.py index 7213a90336..b943256979 100644 --- a/devito/tools/abc.py +++ b/devito/tools/abc.py @@ -143,7 +143,7 @@ def __init__(self, a, b, c=4): kwargs.update({i: getattr(self, i) for i in self.__rkwargs__ if i not in kwargs}) - # Should we use a constum reconstructor? + # Should we use a custom reconstructor? try: cls = self._rcls except AttributeError: diff --git a/tests/test_derivatives.py b/tests/test_derivatives.py index 05f9c40d6d..689b890837 100644 --- a/tests/test_derivatives.py +++ b/tests/test_derivatives.py @@ -713,6 +713,16 @@ def test_deriv_spec(self): assert dxy0.x0 == {y: y+y.spacing/2} assert dxy02.x0 == {x: x+x.spacing/2} + def test_deriv_stagg_plain(self): + grid = Grid((11, 11)) + x, y = grid.dimensions + f1 = Function(name="f1", grid=grid, space_order=2, staggered=NODE) + f2 = Function(name="f2", grid=grid, space_order=2, staggered=NODE) + + eq0 = Eq(f1, f2.laplace).evaluate + assert eq0.rhs == f2.laplace.evaluate + assert eq0.rhs != 0 + class TestTwoStageEvaluation: