Skip to content

Commit

Permalink
compiler: Fix detection of global distributed reductions
Browse files (browse the repository at this point in the history)
  • Loading branch information
FabioLuporini committed May 29, 2024
1 parent 1f16737 commit 0512906
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 13 deletions.
21 changes: 16 additions & 5 deletions devito/ir/clusters/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,14 +447,14 @@ def reduction_comms(clusters):
processed = []
fifo = []
for c in clusters:
# Schedule the global reductions encountered before `c`, if the
# IterationSpace of `c` is such that the reduction can be carried out
# Schedule the global distributed reductions encountered before `c`,
# if `c`'s IterationSpace is such that the reduction can be carried out
found, fifo = split(fifo, lambda dr: dr.ispace.is_subset(c.ispace))
if found:
exprs = [Eq(dr.var, dr) for dr in found]
processed.append(c.rebuild(exprs=exprs))

# Detect the global reductions in `c`
# Detect the global distributed reductions in `c`
for e in c.exprs:
op = e.operation
if op is None or c.is_sparse:
Expand All @@ -465,12 +465,23 @@ def reduction_comms(clusters):
if grid is None:
continue

# The IterationSpace within which the global reduction is carried out
# Is Inc/Max/Min/... actually used for a reduction?
ispace = c.ispace.project(lambda d: d in var.free_symbols)
if ispace.itdims == c.ispace.itdims:
# Inc/Max/Min/... being used for a non-reduction operation
continue

# The reduced Dimensions
rdims = set(c.ispace.itdims) - set(ispace.itdims)

# The reduced Dimensions inducing a global distributed reduction
grdims = {d for d in rdims if d._defines & c.dist_dimensions}
if not grdims:
continue

# The IterationSpace within which the global distributed reduction
# must be carried out
ispace = c.ispace.prefix(lambda d: d in var.free_symbols)

fifo.append(DistReduce(var, op=op, grid=grid, ispace=ispace))

processed.append(c)
Expand Down
13 changes: 13 additions & 0 deletions devito/ir/clusters/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,19 @@ def used_dimensions(self):
idims = set.union(*[set(e.implicit_dims) for e in self.exprs])
return {i for i in self.free_symbols if i.is_Dimension} | idims

@cached_property
def dist_dimensions(self):
    """
    The Cluster's distributed Dimensions, that is the union of the
    `_dist_dimensions` carried by the Cluster's functions, returned as
    a frozenset. Functions exposing no `_dist_dimensions` attribute are
    simply ignored.
    """
    return frozenset(
        dim
        for func in self.functions
        # An empty fallback makes functions without the attribute a no-op
        for dim in getattr(func, '_dist_dimensions', ())
    )

@cached_property
def scope(self):
return Scope(self.exprs)
Expand Down
2 changes: 1 addition & 1 deletion devito/ir/support/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -954,7 +954,7 @@ def prefix(self, key):
try:
i = self.project(key)[-1]
except IndexError:
return None
return null_ispace

return self[:self.index(i.dim) + 1]

Expand Down
13 changes: 6 additions & 7 deletions tests/test_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@

from conftest import _R, assert_blocking, assert_structure
from devito import (Grid, Constant, Function, TimeFunction, SparseFunction,
SparseTimeFunction, Dimension, ConditionalDimension, SubDimension,
SubDomain, Eq, Ne, Inc, NODE, Operator, norm, inner, configuration,
switchconfig, generic_derivative, PrecomputedSparseFunction,
DefaultDimension)
SparseTimeFunction, Dimension, ConditionalDimension,
SubDimension, SubDomain, Eq, Ne, Inc, NODE, Operator, norm,
inner, configuration, switchconfig, generic_derivative,
PrecomputedSparseFunction, DefaultDimension)
from devito.arch.compiler import OneapiCompiler
from devito.data import LEFT, RIGHT
from devito.ir.iet import (Call, Conditional, Iteration, FindNodes, FindSymbols,
retrieve_iteration_tree)
from devito.mpi import MPI
from devito.mpi.routines import (HaloUpdateCall, HaloUpdateList, MPICall,
ComputeCall, AllreduceCall)
ComputeCall)
from devito.mpi.distributed import CustomTopology
from devito.tools import Bunch

Expand Down Expand Up @@ -929,8 +929,7 @@ def test_avoid_haloupdate_as_nostencil_advanced(self, mode):

# No stencil in the expressions, so no halo update required!
calls = FindNodes(Call).visit(op)
assert len(calls) == 2
assert all(isinstance(i, AllreduceCall) for i in calls)
assert len(calls) == 0

@pytest.mark.parallel(mode=1)
def test_avoid_redundant_haloupdate(self, mode):
Expand Down

0 comments on commit 0512906

Please sign in to comment.