Skip to content

Commit

Permalink
compiler: Default CustomDimension.local to True
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioLuporini committed Dec 27, 2024
1 parent 061ef8e commit 3351867
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 4 deletions.
14 changes: 12 additions & 2 deletions devito/types/dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -1342,14 +1342,16 @@ class CustomDimension(BasicDimension):

is_Custom = True

__rkwargs__ = ('symbolic_min', 'symbolic_max', 'symbolic_size', 'parent')
__rkwargs__ = ('symbolic_min', 'symbolic_max', 'symbolic_size', 'parent',
'local')

def __init_finalize__(self, name, symbolic_min=None, symbolic_max=None,
symbolic_size=None, parent=None, **kwargs):
symbolic_size=None, parent=None, local=True, **kwargs):
self._symbolic_min = symbolic_min
self._symbolic_max = symbolic_max
self._symbolic_size = symbolic_size
self._parent = parent or BOTTOM
self._local = local
super().__init_finalize__(name)

@property
Expand Down Expand Up @@ -1382,13 +1384,21 @@ def spacing(self):
else:
return self._spacing

@property
def local(self):
return self._local

@property
def bound_symbols(self):
ret = {self.symbolic_min, self.symbolic_max, self.symbolic_size}
if self.is_Derived:
ret.update(self.parent.bound_symbols)
return frozenset(i for i in ret if i.is_Symbol)

@property
def _maybe_distributed(self):
return not self.local

@cached_property
def _defines(self):
ret = frozenset({self})
Expand Down
21 changes: 19 additions & 2 deletions tests/test_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
Dimension, ConditionalDimension, div, solve, diag, grad,
SubDimension, SubDomain, Eq, Ne, Inc, NODE, Operator, norm,
inner, configuration, switchconfig, generic_derivative,
PrecomputedSparseFunction, DefaultDimension, Buffer)
PrecomputedSparseFunction, DefaultDimension, Buffer,
CustomDimension)
from devito.arch.compiler import OneapiCompiler
from devito.data import LEFT, RIGHT
from devito.ir.iet import (Call, Conditional, Iteration, FindNodes, FindSymbols,
Expand Down Expand Up @@ -1009,7 +1010,7 @@ def test_avoid_haloupdate_if_distr_but_sequential(self, mode):
assert len(calls) == 0

@pytest.mark.parallel(mode=1)
def test_avoid_haloupdate_with_subdims(self, mode):
def test_avoid_haloupdate_with_local_subdims(self, mode):
grid = Grid(shape=(4,))
x = grid.dimensions[0]
t = grid.stepping_dim
Expand All @@ -1034,6 +1035,22 @@ def test_avoid_haloupdate_with_subdims(self, mode):
calls = FindNodes(Call).visit(op)
assert len(calls) == 1

@pytest.mark.parallel(mode=1)
def test_avoid_haloupdate_with_local_customdim(self, mode):
grid = Grid(shape=(10, 10))
x, y = grid.dimensions

d = CustomDimension(name='d', symbolic_min=1, symbolic_max=3, parent=y)

u = TimeFunction(name='u', grid=grid, space_order=4)

eq = Eq(u.forward.subs(y, -d), u.subs(y, d - 1) + 1)

op = Operator(eq)

calls = FindNodes(Call).visit(op)
assert len(calls) == 0

@pytest.mark.parallel(mode=1)
def test_avoid_haloupdate_with_constant_index(self, mode):
grid = Grid(shape=(4,))
Expand Down

0 comments on commit 3351867

Please sign in to comment.