diff --git a/devito/types/dimension.py b/devito/types/dimension.py index 784dde499e..152cf4f627 100644 --- a/devito/types/dimension.py +++ b/devito/types/dimension.py @@ -1342,14 +1342,16 @@ class CustomDimension(BasicDimension): is_Custom = True - __rkwargs__ = ('symbolic_min', 'symbolic_max', 'symbolic_size', 'parent') + __rkwargs__ = ('symbolic_min', 'symbolic_max', 'symbolic_size', 'parent', + 'local') def __init_finalize__(self, name, symbolic_min=None, symbolic_max=None, - symbolic_size=None, parent=None, **kwargs): + symbolic_size=None, parent=None, local=True, **kwargs): self._symbolic_min = symbolic_min self._symbolic_max = symbolic_max self._symbolic_size = symbolic_size self._parent = parent or BOTTOM + self._local = local super().__init_finalize__(name) @property @@ -1382,6 +1384,10 @@ def spacing(self): else: return self._spacing + @property + def local(self): + return self._local + @property def bound_symbols(self): ret = {self.symbolic_min, self.symbolic_max, self.symbolic_size} @@ -1389,6 +1395,10 @@ def bound_symbols(self): ret.update(self.parent.bound_symbols) return frozenset(i for i in ret if i.is_Symbol) + @property + def _maybe_distributed(self): + return not self.local + @cached_property def _defines(self): ret = frozenset({self}) diff --git a/tests/test_mpi.py b/tests/test_mpi.py index 9d2032c472..8b5913fffc 100644 --- a/tests/test_mpi.py +++ b/tests/test_mpi.py @@ -8,7 +8,8 @@ Dimension, ConditionalDimension, div, solve, diag, grad, SubDimension, SubDomain, Eq, Ne, Inc, NODE, Operator, norm, inner, configuration, switchconfig, generic_derivative, - PrecomputedSparseFunction, DefaultDimension, Buffer) + PrecomputedSparseFunction, DefaultDimension, Buffer, + CustomDimension) from devito.arch.compiler import OneapiCompiler from devito.data import LEFT, RIGHT from devito.ir.iet import (Call, Conditional, Iteration, FindNodes, FindSymbols, @@ -1009,7 +1010,7 @@ def test_avoid_haloupdate_if_distr_but_sequential(self, mode): assert len(calls) == 0 @pytest.mark.parallel(mode=1) - def test_avoid_haloupdate_with_subdims(self, mode): + def test_avoid_haloupdate_with_local_subdims(self, mode): grid = Grid(shape=(4,)) x = grid.dimensions[0] t = grid.stepping_dim @@ -1034,6 +1035,22 @@ def test_avoid_haloupdate_with_subdims(self, mode): calls = FindNodes(Call).visit(op) assert len(calls) == 1 + @pytest.mark.parallel(mode=1) + def test_avoid_haloupdate_with_local_customdim(self, mode): + grid = Grid(shape=(10, 10)) + x, y = grid.dimensions + + d = CustomDimension(name='d', symbolic_min=1, symbolic_max=3, parent=y) + + u = TimeFunction(name='u', grid=grid, space_order=4) + + eq = Eq(u.forward.subs(y, -d), u.subs(y, d - 1) + 1) + + op = Operator(eq) + + calls = FindNodes(Call).visit(op) + assert len(calls) == 0 + @pytest.mark.parallel(mode=1) def test_avoid_haloupdate_with_constant_index(self, mode): grid = Grid(shape=(4,))