diff --git a/devito/mpi/routines.py b/devito/mpi/routines.py index e879f8aff1..e62a38c5e6 100644 --- a/devito/mpi/routines.py +++ b/devito/mpi/routines.py @@ -435,23 +435,7 @@ def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs): fixed = {d: Symbol(name="o%s" % d.root) for d in hse.loc_indices} - # Build a mapper `(dim, side, region) -> (size, ofs)` for `f`. `size` and - # `ofs` are symbolic objects. This mapper tells what data values should be - # sent (OWNED) or received (HALO) given dimension and side - mapper = {} - for d0, side, region in product(f.dimensions, (LEFT, RIGHT), (OWNED, HALO)): - if d0 in fixed: - continue - sizes = [] - ofs = [] - for d1 in f.dimensions: - if d1 in fixed: - ofs.append(fixed[d1]) - else: - meta = f._C_get_field(region if d0 is d1 else NOPAD, d1, side) - ofs.append(meta.offset) - sizes.append(meta.size) - mapper[(d0, side, region)] = (sizes, ofs) + mapper = self._make_basic_mapper(f, fixed) body = [] for d in f.dimensions: @@ -483,6 +467,27 @@ def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs): return HaloUpdate('haloupdate%s' % key, iet, parameters) + def _make_basic_mapper(self, f, fixed): + # Build a mapper `(dim, side, region) -> (size, ofs)` for `f`. `size` and + # `ofs` are symbolic objects. This mapper tells what data values should be + # sent (OWNED) or received (HALO) given dimension and side + mapper = {} + for d0, side, region in product(f.dimensions, (LEFT, RIGHT), (OWNED, HALO)): + if d0 in fixed: + continue + sizes = [] + ofs = [] + for d1 in f.dimensions: + if d1 in fixed: + ofs.append(fixed[d1]) + else: + meta = f._C_get_field(region if d0 is d1 else NOPAD, d1, side) + ofs.append(meta.offset) + sizes.append(meta.size) + mapper[(d0, side, region)] = (sizes, ofs) + + return mapper + def _call_haloupdate(self, name, f, hse, *args): comm = f.grid.distributor._obj_comm nb = f.grid.distributor._obj_neighborhood @@ -526,6 +531,125 @@ def _make_body(self, callcompute, remainder, haloupdates, halowaits): return List(body=body) +class Basic2HaloExchangeBuilder(BasicHaloExchangeBuilder): + + """ + A BasicHaloExchangeBuilder making use of pre-allocated buffers for + message size. + + Generates: + + haloupdate() + compute() + """ + + def _make_msg(self, f, hse, key): + # Pass the fixed mapper e.g. {t: otime} + fixed = {d: Symbol(name="o%s" % d.root) for d in hse.loc_indices} + + return MPIMsgBasic2('msg%d' % key, f, hse.halos, fixed) + + def _make_sendrecv(self, f, hse, key, msg=None): + cast = cast_mapper[(f.c0.dtype, '*')] + comm = f.grid.distributor._obj_comm + + bufg = FieldFromPointer(msg._C_field_bufg, msg) + bufs = FieldFromPointer(msg._C_field_bufs, msg) + + ofsg = [Symbol(name='og%s' % d.root) for d in f.dimensions] + ofss = [Symbol(name='os%s' % d.root) for d in f.dimensions] + + fromrank = Symbol(name='fromrank') + torank = Symbol(name='torank') + + sizes = [FieldFromPointer('%s[%d]' % (msg._C_field_sizes, i), msg) + for i in range(len(f._dist_dimensions))] + + arguments = [cast(bufg)] + sizes + list(f.handles) + ofsg + gather = Gather('gather%s' % key, arguments) + # The `gather` is unnecessary if sending to MPI.PROC_NULL + gather = Conditional(CondNe(torank, Macro('MPI_PROC_NULL')), gather) + + arguments = [cast(bufs)] + sizes + list(f.handles) + ofss + scatter = Scatter('scatter%s' % key, arguments) + # The `scatter` must be guarded as we must not alter the halo values along + # the domain boundary, where the sender is actually MPI.PROC_NULL + scatter = Conditional(CondNe(fromrank, Macro('MPI_PROC_NULL')), scatter) + + count = reduce(mul, sizes, 1)*dtype_len(f.dtype) + rrecv = Byref(FieldFromPointer(msg._C_field_rrecv, msg)) + rsend = Byref(FieldFromPointer(msg._C_field_rsend, msg)) + recv = IrecvCall([bufs, count, Macro(dtype_to_mpitype(f.dtype)), + fromrank, Integer(13), comm, rrecv]) + send = IsendCall([bufg, count, Macro(dtype_to_mpitype(f.dtype)), + torank, Integer(13), comm, rsend]) + + waitrecv = Call('MPI_Wait', [rrecv, Macro('MPI_STATUS_IGNORE')]) + waitsend = Call('MPI_Wait', [rsend, Macro('MPI_STATUS_IGNORE')]) + + iet = List(body=[recv, gather, send, waitsend, waitrecv, scatter]) + + parameters = (list(f.handles) + ofsg + ofss + [fromrank, torank, comm, msg]) + + return SendRecv('sendrecv%s' % key, iet, parameters, bufg, bufs) + + def _call_sendrecv(self, name, *args, msg=None, haloid=None): + # Drop `sizes` as this HaloExchangeBuilder conveys them through `msg` + f, _, ofsg, ofss, fromrank, torank, comm = args + msg = Byref(IndexedPointer(msg, haloid)) + return Call(name, list(f.handles) + ofsg + ofss + [fromrank, torank, comm, msg]) + + def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs): + distributor = f.grid.distributor + nb = distributor._obj_neighborhood + comm = distributor._obj_comm + + fixed = {d: Symbol(name="o%s" % d.root) for d in hse.loc_indices} + + mapper = self._make_basic_mapper(f, fixed) + + body = [] + for d in f.dimensions: + if d in fixed: + continue + + name = ''.join('r' if i is d else 'c' for i in distributor.dimensions) + rpeer = FieldFromPointer(name, nb) + name = ''.join('l' if i is d else 'c' for i in distributor.dimensions) + lpeer = FieldFromPointer(name, nb) + + if (d, LEFT) in hse.halos: + # Sending to left, receiving from right + lsizes, lofs = mapper[(d, LEFT, OWNED)] + rsizes, rofs = mapper[(d, RIGHT, HALO)] + args = [f, lsizes, lofs, rofs, rpeer, lpeer, comm] + kwargs['haloid'] = len(body) + body.append(self._call_sendrecv(sendrecv.name, *args, **kwargs)) + + if (d, RIGHT) in hse.halos: + # Sending to right, receiving from left + rsizes, rofs = mapper[(d, RIGHT, OWNED)] + lsizes, lofs = mapper[(d, LEFT, HALO)] + args = [f, rsizes, rofs, lofs, lpeer, rpeer, comm] + kwargs['haloid'] = len(body) + body.append(self._call_sendrecv(sendrecv.name, *args, **kwargs)) + + iet = List(body=body) + + parameters = list(f.handles) + [comm, nb] + list(fixed.values()) + + node = HaloUpdate('haloupdate%s' % key, iet, parameters) + + node = node._rebuild(parameters=node.parameters + (kwargs['msg'],)) + + return node + + def _call_haloupdate(self, name, f, hse, msg): + call = super()._call_haloupdate(name, f, hse) + call = call._rebuild(arguments=call.arguments + (msg,)) + return call + + class DiagHaloExchangeBuilder(BasicHaloExchangeBuilder): """ @@ -741,141 +865,6 @@ def _call_remainder(self, remainder): return call -class Basic2HaloExchangeBuilder(BasicHaloExchangeBuilder): - - """ - A BasicHaloExchangeBuilder making use of pre-allocated buffers for - message size. - - Generates: - - haloupdate() - compute() - """ - - def _make_msg(self, f, hse, key): - # Pass the fixed mapper e.g. {t: otime} - fixed = {d: Symbol(name="o%s" % d.root) for d in hse.loc_indices} - - return MPIMsgBasic2('msg%d' % key, f, hse.halos, fixed) - - def _make_sendrecv(self, f, hse, key, msg=None): - cast = cast_mapper[(f.c0.dtype, '*')] - comm = f.grid.distributor._obj_comm - - bufg = FieldFromPointer(msg._C_field_bufg, msg) - bufs = FieldFromPointer(msg._C_field_bufs, msg) - - ofsg = [Symbol(name='og%s' % d.root) for d in f.dimensions] - ofss = [Symbol(name='os%s' % d.root) for d in f.dimensions] - - fromrank = Symbol(name='fromrank') - torank = Symbol(name='torank') - - sizes = [FieldFromPointer('%s[%d]' % (msg._C_field_sizes, i), msg) - for i in range(len(f._dist_dimensions))] - - arguments = [cast(bufg)] + sizes + list(f.handles) + ofsg - gather = Gather('gather%s' % key, arguments) - # The `gather` is unnecessary if sending to MPI.PROC_NULL - gather = Conditional(CondNe(torank, Macro('MPI_PROC_NULL')), gather) - - arguments = [cast(bufs)] + sizes + list(f.handles) + ofss - scatter = Scatter('scatter%s' % key, arguments) - # The `scatter` must be guarded as we must not alter the halo values along - # the domain boundary, where the sender is actually MPI.PROC_NULL - scatter = Conditional(CondNe(fromrank, Macro('MPI_PROC_NULL')), scatter) - - count = reduce(mul, sizes, 1)*dtype_len(f.dtype) - rrecv = Byref(FieldFromPointer(msg._C_field_rrecv, msg)) - rsend = Byref(FieldFromPointer(msg._C_field_rsend, msg)) - recv = IrecvCall([bufs, count, Macro(dtype_to_mpitype(f.dtype)), - fromrank, Integer(13), comm, rrecv]) - send = IsendCall([bufg, count, Macro(dtype_to_mpitype(f.dtype)), - torank, Integer(13), comm, rsend]) - - waitrecv = Call('MPI_Wait', [rrecv, Macro('MPI_STATUS_IGNORE')]) - waitsend = Call('MPI_Wait', [rsend, Macro('MPI_STATUS_IGNORE')]) - - iet = List(body=[recv, gather, send, waitsend, waitrecv, scatter]) - - parameters = (list(f.handles) + ofsg + ofss + [fromrank, torank, comm, msg]) - - return SendRecv('sendrecv%s' % key, iet, parameters, bufg, bufs) - - def _call_sendrecv(self, name, *args, msg=None, haloid=None): - # Drop `sizes` as this HaloExchangeBuilder conveys them through `msg` - f, _, ofsg, ofss, fromrank, torank, comm = args - msg = Byref(IndexedPointer(msg, haloid)) - return Call(name, list(f.handles) + ofsg + ofss + [fromrank, torank, comm, msg]) - - def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs): - distributor = f.grid.distributor - nb = distributor._obj_neighborhood - comm = distributor._obj_comm - - fixed = {d: Symbol(name="o%s" % d.root) for d in hse.loc_indices} - - # Build a mapper `(dim, side, region) -> (size, ofs)` for `f`. `size` and - # `ofs` are symbolic objects. This mapper tells what data values should be - # sent (OWNED) or received (HALO) given dimension and side - mapper = {} - for d0, side, region in product(f.dimensions, (LEFT, RIGHT), (OWNED, HALO)): - if d0 in fixed: - continue - sizes = [] - ofs = [] - for d1 in f.dimensions: - if d1 in fixed: - ofs.append(fixed[d1]) - else: - meta = f._C_get_field(region if d0 is d1 else NOPAD, d1, side) - ofs.append(meta.offset) - sizes.append(meta.size) - mapper[(d0, side, region)] = (sizes, ofs) - - body = [] - for d in f.dimensions: - if d in fixed: - continue - - name = ''.join('r' if i is d else 'c' for i in distributor.dimensions) - rpeer = FieldFromPointer(name, nb) - name = ''.join('l' if i is d else 'c' for i in distributor.dimensions) - lpeer = FieldFromPointer(name, nb) - - if (d, LEFT) in hse.halos: - # Sending to left, receiving from right - lsizes, lofs = mapper[(d, LEFT, OWNED)] - rsizes, rofs = mapper[(d, RIGHT, HALO)] - args = [f, lsizes, lofs, rofs, rpeer, lpeer, comm] - kwargs['haloid'] = len(body) - body.append(self._call_sendrecv(sendrecv.name, *args, **kwargs)) - - if (d, RIGHT) in hse.halos: - # Sending to right, receiving from left - rsizes, rofs = mapper[(d, RIGHT, OWNED)] - lsizes, lofs = mapper[(d, LEFT, HALO)] - args = [f, rsizes, rofs, lofs, lpeer, rpeer, comm] - kwargs['haloid'] = len(body) - body.append(self._call_sendrecv(sendrecv.name, *args, **kwargs)) - - iet = List(body=body) - - parameters = list(f.handles) + [comm, nb] + list(fixed.values()) - - node = HaloUpdate('haloupdate%s' % key, iet, parameters) - - node = node._rebuild(parameters=node.parameters + (kwargs['msg'],)) - - return node - - def _call_haloupdate(self, name, f, hse, msg): - call = super()._call_haloupdate(name, f, hse) - call = call._rebuild(arguments=call.arguments + (msg,)) - return call - - class Overlap2HaloExchangeBuilder(OverlapHaloExchangeBuilder): """ @@ -1425,9 +1414,7 @@ def _arg_defaults(self, allocator, alias, args=None): fixed = self._fixed - # Build a mapper `(dim, side, region) -> (size, ofs)` for `f`. `size` and - # `ofs` are symbolic objects. This mapper tells what data values should be - # sent (OWNED) or received (HALO) given dimension and side + # Build a mapper `(dim, side, region) -> (size)` for `f`. mapper = {} for d0, side, region in product(f.dimensions, (LEFT, RIGHT), (OWNED, HALO)): if d0 in fixed: