From 1f16737feea5f36492f123c65ff231fe009555b4 Mon Sep 17 00:00:00 2001
From: Fabio Luporini
Date: Tue, 2 Apr 2024 15:00:48 +0000
Subject: [PATCH] compiler: Add support for C-level MPI_Allreduce

---
 devito/builtins/arithmetic.py    | 60 +++++++++++++------------
 devito/builtins/utils.py         | 52 ++++++----------------
 devito/core/gpu.py               | 19 ++++----
 devito/ir/clusters/algorithms.py | 75 ++++++++++++--------------------
 devito/ir/clusters/cluster.py    | 14 +++---
 devito/mpi/reduction_scheme.py   | 20 +++++----
 devito/mpi/routines.py           |  7 +--
 devito/operator/operator.py      | 21 +++++----
 devito/passes/iet/langbase.py    |  1 -
 devito/passes/iet/mpi.py         | 43 ++++++++++--------
 tests/test_builtins.py           | 14 +++++-
 tests/test_mpi.py                |  6 ++-
 12 files changed, 161 insertions(+), 171 deletions(-)

diff --git a/devito/builtins/arithmetic.py b/devito/builtins/arithmetic.py
index bb0e31806d..f24a0e56ea 100644
--- a/devito/builtins/arithmetic.py
+++ b/devito/builtins/arithmetic.py
@@ -1,7 +1,7 @@
 import numpy as np
 
 import devito as dv
-from devito.builtins.utils import MPIReduction
+from devito.builtins.utils import make_retval
 
 __all__ = ['norm', 'sumall', 'sum', 'inner', 'mmin', 'mmax']
 
@@ -44,15 +44,15 @@ def norm(f, order=2):
     p, eqns = f.guard() if f.is_SparseFunction else (f, [])
 
     dtype = accumulator_mapper[f.dtype]
+    n = make_retval(f.grid, dtype)
     s = dv.types.Symbol(name='sum', dtype=dtype)
 
-    with MPIReduction(f, dtype=dtype) as mr:
-        op = dv.Operator([dv.Eq(s, 0.0)] + eqns +
-                         [dv.Inc(s, dv.Abs(Pow(p, order))), dv.Eq(mr.n[0], s)],
-                         name='norm%d' % order)
-        op.apply(**kwargs)
+    op = dv.Operator([dv.Eq(s, 0.0)] + eqns +
+                     [dv.Inc(s, dv.Abs(Pow(p, order))), dv.Eq(n[0], s)],
+                     name='norm%d' % order)
+    op.apply(**kwargs)
 
-    v = np.power(mr.v, 1/order)
+    v = np.power(n.data[0], 1/order)
 
     return f.dtype(v)
 
@@ -129,15 +129,15 @@ def sumall(f):
     p, eqns = f.guard() if f.is_SparseFunction else (f, [])
 
     dtype = accumulator_mapper[f.dtype]
+    n = make_retval(f.grid, dtype)
     s = dv.types.Symbol(name='sum', dtype=dtype)
 
-    with MPIReduction(f, dtype=dtype) as mr:
-        op = dv.Operator([dv.Eq(s, 0.0)] + eqns +
-                         [dv.Inc(s, p), dv.Eq(mr.n[0], s)],
-                         name='sum')
-        op.apply(**kwargs)
+    op = dv.Operator([dv.Eq(s, 0.0)] + eqns +
+                     [dv.Inc(s, p), dv.Eq(n[0], s)],
+                     name='sum')
+    op.apply(**kwargs)
 
-    return f.dtype(mr.v)
+    return f.dtype(n.data[0])
 
 
 @dv.switchconfig(log_level='ERROR')
@@ -184,15 +184,15 @@ def inner(f, g):
     rhs, eqns = f.guard(f*g) if f.is_SparseFunction else (f*g, [])
 
     dtype = accumulator_mapper[f.dtype]
+    n = make_retval(f.grid or g.grid, dtype)
     s = dv.types.Symbol(name='sum', dtype=dtype)
 
-    with MPIReduction(f, g, dtype=dtype) as mr:
-        op = dv.Operator([dv.Eq(s, 0.0)] + eqns +
-                         [dv.Inc(s, rhs), dv.Eq(mr.n[0], s)],
-                         name='inner')
-        op.apply(**kwargs)
+    op = dv.Operator([dv.Eq(s, 0.0)] + eqns +
+                     [dv.Inc(s, rhs), dv.Eq(n[0], s)],
+                     name='inner')
+    op.apply(**kwargs)
 
-    return f.dtype(mr.v)
+    return f.dtype(n.data[0])
 
 
 @dv.switchconfig(log_level='ERROR')
@@ -208,11 +208,14 @@ def mmin(f):
     if isinstance(f, dv.Constant):
         return f.data
     elif isinstance(f, dv.types.dense.DiscreteFunction):
-        with MPIReduction(f, op=dv.mpi.MPI.MIN) as mr:
-            mr.n.data[0] = np.min(f.data_ro_domain).item()
-        return mr.v.item()
+        v = np.min(f.data_ro_domain)
+        if f.grid is None or not dv.configuration['mpi']:
+            return v.item()
+        else:
+            comm = f.grid.distributor.comm
+            return comm.allreduce(v, dv.mpi.MPI.MIN).item()
     else:
-        raise ValueError("Expected Function, not `%s`" % type(f))
+        raise ValueError("Expected Function, got `%s`" % type(f))
 
 
 @dv.switchconfig(log_level='ERROR')
@@ -228,8 +231,11 @@ def mmax(f):
     if isinstance(f, dv.Constant):
         return f.data
     elif isinstance(f, dv.types.dense.DiscreteFunction):
-        with MPIReduction(f, op=dv.mpi.MPI.MAX) as mr:
-            mr.n.data[0] = np.max(f.data_ro_domain).item()
-        return mr.v.item()
+        v = np.max(f.data_ro_domain)
+        if f.grid is None or not dv.configuration['mpi']:
+            return v.item()
+        else:
+            comm = f.grid.distributor.comm
+            return comm.allreduce(v, dv.mpi.MPI.MAX).item()
     else:
-        raise ValueError("Expected Function, not `%s`" % type(f))
+        raise ValueError("Expected Function, got `%s`" % type(f))
diff --git a/devito/builtins/utils.py b/devito/builtins/utils.py
index fe5e0cdb9d..786dbbce48 100644
--- a/devito/builtins/utils.py
+++ b/devito/builtins/utils.py
@@ -1,52 +1,26 @@
 from functools import wraps
 
-import numpy as np
-
 import devito as dv
 from devito.symbolics import uxreplace
 from devito.tools import as_tuple
 
-__all__ = ['MPIReduction', 'nbl_to_padsize', 'pad_outhalo', 'abstract_args']
+__all__ = ['make_retval', 'nbl_to_padsize', 'pad_outhalo', 'abstract_args']
 
 
-class MPIReduction:
+def make_retval(grid, dtype):
     """
-    A context manager to build MPI-aware reduction Operators.
+    Devito does not support passing values by reference. This function
+    creates a dummy Function of size 1 to store the return value of a builtin
+    applied to `f`.
     """
-
-    def __init__(self, *functions, op=dv.mpi.MPI.SUM, dtype=None):
-        grids = {f.grid for f in functions}
-        if len(grids) == 0:
-            self.grid = None
-        elif len(grids) == 1:
-            self.grid = grids.pop()
-        else:
-            raise ValueError("Multiple Grids found")
-        if dtype is not None:
-            self.dtype = dtype
-        else:
-            dtype = {f.dtype for f in functions}
-            if len(dtype) == 1:
-                self.dtype = np.result_type(dtype.pop(), np.float32).type
-            else:
-                raise ValueError("Illegal mixed data types")
-        self.v = None
-        self.op = op
-
-    def __enter__(self):
-        i = dv.Dimension(name='mri',)
-        self.n = dv.Function(name='n', shape=(1,), dimensions=(i,),
-                             grid=self.grid, dtype=self.dtype, space='host')
-        self.n.data[:] = 0
-        return self
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        if self.grid is None or not dv.configuration['mpi']:
-            assert self.n.data.size == 1
-            self.v = self.n.data[0]
-        else:
-            comm = self.grid.distributor.comm
-            self.v = comm.allreduce(np.asarray(self.n.data), self.op)[0]
+    if grid is None:
+        raise ValueError("Expected Grid, got None")
+
+    i = dv.Dimension(name='mri',)
+    n = dv.Function(name='n', shape=(1,), dimensions=(i,), grid=grid,
+                    dtype=dtype, space='host')
+    n.data[:] = 0
+    return n
 
 
 def nbl_to_padsize(nbl, ndim):
diff --git a/devito/core/gpu.py b/devito/core/gpu.py
index 8d7ea75195..266c198647 100644
--- a/devito/core/gpu.py
+++ b/devito/core/gpu.py
@@ -116,19 +116,20 @@ def _normalize_gpu_fit(cls, oo, **kwargs):
         return as_tuple(cls.GPU_FIT)
 
     @classmethod
-    def _rcompile_wrapper(cls, **kwargs0):
-        options = kwargs0['options']
+    def _rcompile_wrapper(cls, **kwargs):
+        def wrapper(expressions, mode='default', **options):
 
-        def wrapper(expressions, mode='default', **kwargs1):
             if mode == 'host':
-                kwargs = {**{
+                par_disabled = kwargs['options']['par-disabled']
+                target = {
                     'platform': 'cpu64',
-                    'language': 'C' if options['par-disabled'] else 'openmp',
-                    'compiler': 'custom',
-                }, **kwargs1}
+                    'language': 'C' if par_disabled else 'openmp',
+                    'compiler': 'custom'
+                }
             else:
-                kwargs = {**kwargs0, **kwargs1}
-            return rcompile(expressions, kwargs)
+                target = None
+
+            return rcompile(expressions, kwargs, options, target=target)
 
         return wrapper
 
diff --git a/devito/ir/clusters/algorithms.py b/devito/ir/clusters/algorithms.py
index bce705b40a..b2b552530d 100644
--- a/devito/ir/clusters/algorithms.py
+++ b/devito/ir/clusters/algorithms.py
@@ -14,11 +14,11 @@
 from devito.ir.clusters.cluster import Cluster, ClusterGroup
 from devito.ir.clusters.visitors import Queue, QueueStateful, cluster_pass
 from devito.mpi.halo_scheme import HaloScheme, HaloTouch
-from devito.mpi.reduction_scheme import DistributedReduction
+from devito.mpi.reduction_scheme import DistReduce
 from devito.symbolics import (limits_mapper, retrieve_indexed, uxreplace,
                               xreplace_indices)
 from devito.tools import (DefaultOrderedDict, Stamp, as_mapper, flatten,
-                          is_integer, timed_pass, toposort)
+                          is_integer, split, timed_pass, toposort)
 from devito.types import Array, Eq, Symbol
 from devito.types.dimension import BOTTOM, ModuloDimension
 
@@ -378,12 +378,10 @@ def communications(clusters):
     return clusters
 
 
-class Comms(Queue):
-    #TODO: MAYBE DROP ME
+class HaloComms(Queue):
 
     """
-    Abstract base class for injecting Clusters representing communications
-    for distributed-memory parallelism.
+    Inject Clusters representing halo exchanges for distributed-memory parallelism.
     """
 
     _q_guards_in_key = True
@@ -391,13 +389,6 @@ class Comms(Queue):
 
     B = Symbol(name='⊥')
 
-
-class HaloComms(Comms):
-
-    """
-    A specialization of Comms to handle halo exchanges.
-    """
-
     def process(self, clusters):
         return self._process_fatd(clusters, 1, seen=set())
 
@@ -453,49 +444,41 @@ def callback(self, clusters, prefix, seen=None):
 
 
 def reduction_comms(clusters):
-    # Detect the underlying Grid
-    #TODO: pretty rudimentary, but it's a start
-    for c in clusters:
-        try:
-            grid = c.grid
-            break
-        except ValueError:
-            continue
-    else:
-        return clusters
-
-    # Detect global reductions along the distributed Dimensions
-    found = {}
+    processed = []
+    fifo = []
     for c in clusters:
-        if not any(grid.is_distributed(d) for d in c.ispace.itdims):
-            continue
-
+        # Schedule the global reductions encountered before `c`, if the
+        # IterationSpace of `c` is such that the reduction can be carried out
+        found, fifo = split(fifo, lambda dr: dr.ispace.is_subset(c.ispace))
+        if found:
+            exprs = [Eq(dr.var, dr) for dr in found]
+            processed.append(c.rebuild(exprs=exprs))
+
+        # Detect the global reductions in `c`
         for e in c.exprs:
             op = e.operation
-            if op is None:
+            if op is None or c.is_sparse:
                 continue
-            elif found.get(e.lhs, op) != op:
-                raise ValueError("Inconsistent reduction operations")
-            else:
-                found[e.lhs] = e.operation
 
-    # Place global reductions right before they're required
-    processed = []
-    for c in clusters:
-        for var, op in list(found.items()):
-            if var in c.scope.read_only:
-                expr = Eq(var, DistributedReduction(var, op=op, grid=grid))
-                processed.append(c.rebuild(exprs=expr))
+            var = e.lhs
+            grid = c.grid
+            if grid is None:
+                continue
+
+            # The IterationSpace within which the global reduction is carried out
+            ispace = c.ispace.project(lambda d: d in var.free_symbols)
+            if ispace.itdims == c.ispace.itdims:
+                # Inc/Max/Min/... being used for a non-reduction operation
+                continue
 
-            found.pop(var)
+            fifo.append(DistReduce(var, op=op, grid=grid, ispace=ispace))
 
         processed.append(c)
 
     # Leftover reductions are placed at the very end
-    while found:
-        var, op = found.popitem()
-        expr = Eq(var, DistributedReduction(var, op=op, grid=grid))
-        processed.append(Cluster(exprs=[expr], ispace=null_ispace))
+    if fifo:
+        exprs = [Eq(dr.var, dr) for dr in fifo]
+        processed.append(Cluster(exprs=exprs, ispace=null_ispace))
 
     return processed
 
diff --git a/devito/ir/clusters/cluster.py b/devito/ir/clusters/cluster.py
index 5e56be055a..a429e4714f 100644
--- a/devito/ir/clusters/cluster.py
+++ b/devito/ir/clusters/cluster.py
@@ -11,7 +11,7 @@
                                normalize_properties, normalize_syncs, minimum, maximum,
                                null_ispace)
 from devito.mpi.halo_scheme import HaloScheme, HaloTouch
-from devito.mpi.reduction_scheme import DistributedReduction
+from devito.mpi.reduction_scheme import DistReduce
 from devito.symbolics import estimate_cost
 from devito.tools import as_tuple, flatten, frozendict, infer_dtype
 from devito.types import WeakFence, CriticalRegion
@@ -181,8 +181,11 @@ def has_increments(self):
 
     @cached_property
     def grid(self):
-        grids = set(f.grid for f in self.functions if f.is_DiscreteFunction) - {None}
-        if len(grids) == 1:
+        grids = set(f.grid for f in self.functions if f.is_AbstractFunction)
+        grids.discard(None)
+        if len(grids) == 0:
+            return None
+        elif len(grids) == 1:
             return grids.pop()
         else:
             raise ValueError("Cluster has no unique Grid")
@@ -211,7 +214,7 @@ def is_dense(self):
             dims = {d for d in self.properties if d._defines & target}
             if any(pset & self.properties[d] for d in dims):
                 return True
-        except ValueError:
+        except (AttributeError, ValueError):
             pass
 
         # Fallback to legacy is_dense checks
@@ -241,8 +244,7 @@ def is_halo_touch(self):
 
     @property
     def is_dist_reduce(self):
-        return self.exprs and all(isinstance(e.rhs, DistributedReduction)
-                                  for e in self.exprs)
+        return self.exprs and all(isinstance(e.rhs, DistReduce) for e in self.exprs)
 
     @property
     def is_fence(self):
diff --git a/devito/mpi/reduction_scheme.py b/devito/mpi/reduction_scheme.py
index bf2ccb0bbd..f3a412f07d 100644
--- a/devito/mpi/reduction_scheme.py
+++ b/devito/mpi/reduction_scheme.py
@@ -2,26 +2,27 @@
 
 from devito.tools import Reconstructable
 
-__all__ = ['DistributedReduction']
+__all__ = ['DistReduce']
 
 
-class DistributedReduction(sympy.Function, Reconstructable):
+class DistReduce(sympy.Function, Reconstructable):
 
     """
     A SymPy object representing a distributed Reduction.
""" __rargs__ = ('var',) - __rkwargs__ = ('op', 'grid') + __rkwargs__ = ('op', 'grid', 'ispace') - def __new__(cls, var, op=None, grid=None, **kwargs): + def __new__(cls, var, op=None, grid=None, ispace=None, **kwargs): obj = sympy.Function.__new__(cls, var, **kwargs) obj.op = op obj.grid = grid + obj.ispace = ispace return obj def __repr__(self): - return "DistributedReduction(%s,%s)" % (self.var, self.op) + return "DistReduce(%s,%s)" % (self.var, self.op) __str__ = __repr__ @@ -29,13 +30,16 @@ def _sympystr(self, printer): return str(self) def _hashable_content(self): - return (self.op, self.grid) + return (self.op, self.grid, self.ispace) def __eq__(self, other): - return (isinstance(other, DistributedReduction) and + return (isinstance(other, DistReduce) and self.var == other.var and self.op == other.op and - self.grid == other.grid) + self.grid == other.grid and + self.ispace == other.ispace) + + __hash__ = sympy.Function.__hash__ func = Reconstructable._rebuild diff --git a/devito/mpi/routines.py b/devito/mpi/routines.py index 2968a2d4fb..46fc2a8e3a 100644 --- a/devito/mpi/routines.py +++ b/devito/mpi/routines.py @@ -21,7 +21,7 @@ from devito.types import (Array, Bag, Dimension, Eq, Symbol, LocalObject, CompositeObject, CustomDimension) -__all__ = ['HaloExchangeBuilder', 'ReductionBuilder, ''mpi_registry'] +__all__ = ['HaloExchangeBuilder', 'ReductionBuilder', 'mpi_registry'] class HaloExchangeBuilder: @@ -30,7 +30,8 @@ class HaloExchangeBuilder: Build IET routines to generate MPI halo exchanges. """ - def __new__(cls, mpimode, generators=None, rcompile=None, sregistry=None, **kwargs): + def __new__(cls, mpimode, generators=None, rcompile=None, sregistry=None, + **kwargs): obj = object.__new__(mpi_registry[mpimode]) obj.rcompile = rcompile @@ -370,7 +371,7 @@ def _make_copy(self, f, hse, key, swap=False): eqns.append(Eq(*swap(buf[[i] + bdims], f[[i] + findices]))) # Compile `eqns` into an IET via recursive compilation - irs, _ = self.rcompile(eqns) + irs, _ = self.rcompile(eqns, mpi=False) parameters = [buf] + bshape + list(f.handles) + ofs diff --git a/devito/operator/operator.py b/devito/operator/operator.py index d1d2daeb5c..5c781512b9 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -271,9 +271,9 @@ def _lower(cls, expressions, **kwargs): return IRs(expressions, clusters, stree, uiet, iet), byproduct @classmethod - def _rcompile_wrapper(cls, **kwargs0): - def wrapper(expressions, **kwargs1): - return rcompile(expressions, {**kwargs0, **kwargs1}) + def _rcompile_wrapper(cls, **kwargs): + def wrapper(expressions, **options): + return rcompile(expressions, kwargs, options) return wrapper @classmethod @@ -1049,26 +1049,25 @@ def __setstate__(self, state): # if applied in cascade (e.g., `linearization` on top of `linearization`) rcompile_registry = { 'avoid_denormals': False, - #'mpi': False, #TODO: DROP / DON'T DROP?? NEED IT FOR GLB REDUCTIONS... 'linearize': False, 'place-transfers': False } -def rcompile(expressions, kwargs=None): +def rcompile(expressions, kwargs, options, target=None): """ Perform recursive compilation on an ordered sequence of symbolic expressions. 
""" - if not kwargs or 'options' not in kwargs: - kwargs = parse_kwargs(**kwargs) + options = {**kwargs['options'], **rcompile_registry, **options} + + if target is None: cls = operator_selector(**kwargs) - kwargs = cls._normalize_kwargs(**kwargs) else: + kwargs = parse_kwargs(**target) cls = operator_selector(**kwargs) + kwargs = cls._normalize_kwargs(**kwargs) - # Tweak the compilation kwargs - options = dict(kwargs['options']) - options.update(rcompile_registry) + # Use the customized opt options kwargs['options'] = options # Recursive profiling not supported -- would be a complete mess diff --git a/devito/passes/iet/langbase.py b/devito/passes/iet/langbase.py index 0331760be8..d27674c419 100644 --- a/devito/passes/iet/langbase.py +++ b/devito/passes/iet/langbase.py @@ -432,7 +432,6 @@ def _(iet): break except AttributeError: pass - assert objcomm is not None devicetype = as_list(self.lang[self.platform]) deviceid = self.deviceid diff --git a/devito/passes/iet/mpi.py b/devito/passes/iet/mpi.py index 1a1e4d122a..3a714e095a 100644 --- a/devito/passes/iet/mpi.py +++ b/devito/passes/iet/mpi.py @@ -7,7 +7,7 @@ retrieve_iteration_tree) from devito.ir.support import PARALLEL, Scope from devito.mpi.halo_scheme import HaloScheme -from devito.mpi.reduction_scheme import DistributedReduction +from devito.mpi.reduction_scheme import DistReduce from devito.mpi.routines import HaloExchangeBuilder, ReductionBuilder from devito.passes.iet.engine import iet_pass from devito.tools import generator @@ -297,15 +297,13 @@ def _mark_overlappable(iet): @iet_pass -def make_mpi(iet, mpimode=None, **kwargs): +def make_halo_exchanges(iet, mpimode=None, **kwargs): """ - Inject MPI Callables and Calls implementing halo exchanges and reductions for - distributed-memory parallelism. + Lower HaloSpots into halo exchanges for distributed-memory parallelism. """ # To produce unique object names generators = {'msg': generator(), 'comm': generator(), 'comp': generator()} - # Halo exchanges sync_heb = HaloExchangeBuilder('basic', generators, **kwargs) user_heb = HaloExchangeBuilder(mpimode, generators, **kwargs) mapper = {} @@ -330,26 +328,33 @@ def make_mpi(iet, mpimode=None, **kwargs): break iet = Transformer(mapper, nested=True).visit(iet) - # Reductions + return iet, {'includes': ['mpi.h'], 'efuncs': efuncs} + + +@iet_pass +def make_reductions(iet, mpimode=None, **kwargs): rb = ReductionBuilder() - mapper = {e: rb.make(e.expr.rhs) for e in FindNodes(Expression).visit(iet) - if isinstance(e.expr.rhs, DistributedReduction)} + + mapper = {} + for e in FindNodes(Expression).visit(iet): + if not isinstance(e.expr.rhs, DistReduce): + continue + elif mpimode: + mapper[e] = rb.make(e.expr.rhs) + else: + mapper[e] = None iet = Transformer(mapper, nested=True).visit(iet) - return iet, {'includes': ['mpi.h'], 'efuncs': efuncs} + return iet, {} def mpiize(graph, **kwargs): """ - Perform two IET passes: + Perform three IET passes: - * Optimization of communications - * Injection of MPI code - - The former is implemented by manipulating HaloSpots. - - The latter resorts to creating MPI Callables and replacing HaloSpots with Calls - to MPI Callables. 
+        * Optimization of halo exchanges
+        * Injection of code for halo exchanges
+        * Injection of code for reductions
     """
     options = kwargs['options']
@@ -358,4 +363,6 @@ def mpiize(graph, **kwargs):
     mpimode = options['mpi']
 
     if mpimode:
-        make_mpi(graph, mpimode=mpimode, **kwargs)
+        make_halo_exchanges(graph, mpimode=mpimode, **kwargs)
+
+    make_reductions(graph, mpimode=mpimode, **kwargs)
diff --git a/tests/test_builtins.py b/tests/test_builtins.py
index 21b4ca0830..32e60c0912 100644
--- a/tests/test_builtins.py
+++ b/tests/test_builtins.py
@@ -448,7 +448,7 @@ def test_sum_sparse(self):
 
     def test_min_max_sparse(self):
         """
-        Test that mmin/mmax work on SparseFunction
+        Test that mmin/mmax work on SparseFunction.
        """
         grid = Grid((101, 101), extent=(1000., 1000.))
 
@@ -464,6 +464,18 @@ def test_min_max_sparse(self):
         term2 = mmax(rec0)
         assert np.isclose(term1/term2 - 1, 0.0, rtol=0.0, atol=1e-5)
 
+    @pytest.mark.parallel(mode=4)
+    def test_min_max_mpi(self):
+        grid = Grid(shape=(100, 100))
+
+        f = Function(name='f', grid=grid)
+
+        # Populate data with increasing values starting at 1
+        f.data[:] = np.arange(1, 10001).reshape((100, 100))
+
+        assert mmin(f) == 1
+        assert mmax(f) == 10000
+
     def test_issue_1860(self):
         grid = Grid(shape=(401, 301, 181))
 
diff --git a/tests/test_mpi.py b/tests/test_mpi.py
index 2bb2ed04d7..d7bd5188d5 100644
--- a/tests/test_mpi.py
+++ b/tests/test_mpi.py
@@ -13,7 +13,8 @@
 from devito.ir.iet import (Call, Conditional, Iteration, FindNodes, FindSymbols,
                            retrieve_iteration_tree)
 from devito.mpi import MPI
-from devito.mpi.routines import HaloUpdateCall, HaloUpdateList, MPICall, ComputeCall
+from devito.mpi.routines import (HaloUpdateCall, HaloUpdateList, MPICall,
+                                 ComputeCall, AllreduceCall)
 from devito.mpi.distributed import CustomTopology
 from devito.tools import Bunch
 
@@ -928,7 +929,8 @@ def test_avoid_haloupdate_as_nostencil_advanced(self, mode):
 
         # No stencil in the expressions, so no halo update required!
         calls = FindNodes(Call).visit(op)
-        assert len(calls) == 0
+        assert len(calls) == 2
+        assert all(isinstance(i, AllreduceCall) for i in calls)
 
     @pytest.mark.parallel(mode=1)
     def test_avoid_redundant_haloupdate(self, mode):
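
Usage note: the sketch below is not part of the patch. It mirrors the new test_min_max_mpi test added in tests/test_builtins.py and points at the code paths this change touches; the DEVITO_MPI=1 setting and the mpirun launch line are assumptions about the runtime environment rather than anything introduced here.

# A minimal sketch, assuming an MPI-enabled run such as:
#     DEVITO_MPI=1 mpirun -n 4 python example.py
import numpy as np

from devito import Function, Grid, mmax, mmin, norm

grid = Grid(shape=(100, 100))
f = Function(name='f', grid=grid)

# Same data as test_min_max_mpi: values 1..10000, distributed across the
# MPI ranks by the Grid's decomposition
f.data[:] = np.arange(1, 10001).reshape((100, 100))

# mmin/mmax reduce the rank-local min/max on the Python side via
# comm.allreduce (see the rewritten mmin/mmax in devito/builtins/arithmetic.py)
assert mmin(f) == 1
assert mmax(f) == 10000

# norm/sumall/inner accumulate into the size-1 Function built by make_retval;
# the cross-rank reduction is expressed as a DistReduce, which mpiize lowers
# to a C-level MPI_Allreduce inside the generated Operator
print(norm(f, order=2))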