Skip to content

Commit

Permalink
mpi: add make_basic_mapper
Browse files Browse the repository at this point in the history
  • Loading branch information
georgebisbas committed Feb 12, 2024
1 parent b4e67e3 commit befd078
Showing 1 changed file with 131 additions and 142 deletions.
273 changes: 131 additions & 142 deletions devito/mpi/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,13 +428,7 @@ def _call_sendrecv(self, name, *args, **kwargs):
args = list(args[0].handles) + flatten(args[1:])
return Call(name, args)

def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs):
distributor = f.grid.distributor
nb = distributor._obj_neighborhood
comm = distributor._obj_comm

fixed = {d: Symbol(name="o%s" % d.root) for d in hse.loc_indices}

def _make_basic_mapper(self, f, fixed):
# Build a mapper `(dim, side, region) -> (size, ofs)` for `f`. `size` and
# `ofs` are symbolic objects. This mapper tells what data values should be
# sent (OWNED) or received (HALO) given dimension and side
Expand All @@ -453,6 +447,17 @@ def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs):
sizes.append(meta.size)
mapper[(d0, side, region)] = (sizes, ofs)

return mapper

def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs):
distributor = f.grid.distributor
nb = distributor._obj_neighborhood
comm = distributor._obj_comm

fixed = {d: Symbol(name="o%s" % d.root) for d in hse.loc_indices}

mapper = self._make_basic_mapper(f, fixed)

body = []
for d in f.dimensions:
if d in fixed:
Expand Down Expand Up @@ -526,6 +531,125 @@ def _make_body(self, callcompute, remainder, haloupdates, halowaits):
return List(body=body)


class Basic2HaloExchangeBuilder(BasicHaloExchangeBuilder):

"""
A BasicHaloExchangeBuilder making use of pre-allocated buffers for
message size.
Generates:
haloupdate()
compute()
"""

def _make_msg(self, f, hse, key):
# Pass the fixed mapper e.g. {t: otime}
fixed = {d: Symbol(name="o%s" % d.root) for d in hse.loc_indices}

return MPIMsgBasic2('msg%d' % key, f, hse.halos, fixed)

def _make_sendrecv(self, f, hse, key, msg=None):
cast = cast_mapper[(f.c0.dtype, '*')]
comm = f.grid.distributor._obj_comm

bufg = FieldFromPointer(msg._C_field_bufg, msg)
bufs = FieldFromPointer(msg._C_field_bufs, msg)

ofsg = [Symbol(name='og%s' % d.root) for d in f.dimensions]
ofss = [Symbol(name='os%s' % d.root) for d in f.dimensions]

fromrank = Symbol(name='fromrank')
torank = Symbol(name='torank')

sizes = [FieldFromPointer('%s[%d]' % (msg._C_field_sizes, i), msg)
for i in range(len(f._dist_dimensions))]

arguments = [cast(bufg)] + sizes + list(f.handles) + ofsg
gather = Gather('gather%s' % key, arguments)
# The `gather` is unnecessary if sending to MPI.PROC_NULL
gather = Conditional(CondNe(torank, Macro('MPI_PROC_NULL')), gather)

arguments = [cast(bufs)] + sizes + list(f.handles) + ofss
scatter = Scatter('scatter%s' % key, arguments)
# The `scatter` must be guarded as we must not alter the halo values along
# the domain boundary, where the sender is actually MPI.PROC_NULL
scatter = Conditional(CondNe(fromrank, Macro('MPI_PROC_NULL')), scatter)

count = reduce(mul, sizes, 1)*dtype_len(f.dtype)
rrecv = Byref(FieldFromPointer(msg._C_field_rrecv, msg))
rsend = Byref(FieldFromPointer(msg._C_field_rsend, msg))
recv = IrecvCall([bufs, count, Macro(dtype_to_mpitype(f.dtype)),
fromrank, Integer(13), comm, rrecv])
send = IsendCall([bufg, count, Macro(dtype_to_mpitype(f.dtype)),
torank, Integer(13), comm, rsend])

waitrecv = Call('MPI_Wait', [rrecv, Macro('MPI_STATUS_IGNORE')])
waitsend = Call('MPI_Wait', [rsend, Macro('MPI_STATUS_IGNORE')])

iet = List(body=[recv, gather, send, waitsend, waitrecv, scatter])

parameters = (list(f.handles) + ofsg + ofss + [fromrank, torank, comm, msg])

return SendRecv('sendrecv%s' % key, iet, parameters, bufg, bufs)

def _call_sendrecv(self, name, *args, msg=None, haloid=None):
# Drop `sizes` as this HaloExchangeBuilder conveys them through `msg`
f, _, ofsg, ofss, fromrank, torank, comm = args
msg = Byref(IndexedPointer(msg, haloid))
return Call(name, list(f.handles) + ofsg + ofss + [fromrank, torank, comm, msg])

def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs):
distributor = f.grid.distributor
nb = distributor._obj_neighborhood
comm = distributor._obj_comm

fixed = {d: Symbol(name="o%s" % d.root) for d in hse.loc_indices}

mapper = self._make_basic_mapper(f, fixed)

body = []
for d in f.dimensions:
if d in fixed:
continue

name = ''.join('r' if i is d else 'c' for i in distributor.dimensions)
rpeer = FieldFromPointer(name, nb)
name = ''.join('l' if i is d else 'c' for i in distributor.dimensions)
lpeer = FieldFromPointer(name, nb)

if (d, LEFT) in hse.halos:
# Sending to left, receiving from right
lsizes, lofs = mapper[(d, LEFT, OWNED)]
rsizes, rofs = mapper[(d, RIGHT, HALO)]
args = [f, lsizes, lofs, rofs, rpeer, lpeer, comm]
kwargs['haloid'] = len(body)
body.append(self._call_sendrecv(sendrecv.name, *args, **kwargs))

if (d, RIGHT) in hse.halos:
# Sending to right, receiving from left
rsizes, rofs = mapper[(d, RIGHT, OWNED)]
lsizes, lofs = mapper[(d, LEFT, HALO)]
args = [f, rsizes, rofs, lofs, lpeer, rpeer, comm]
kwargs['haloid'] = len(body)
body.append(self._call_sendrecv(sendrecv.name, *args, **kwargs))

iet = List(body=body)

parameters = list(f.handles) + [comm, nb] + list(fixed.values())

node = HaloUpdate('haloupdate%s' % key, iet, parameters)

node = node._rebuild(parameters=node.parameters + (kwargs['msg'],))

return node

def _call_haloupdate(self, name, f, hse, msg):
call = super()._call_haloupdate(name, f, hse)
call = call._rebuild(arguments=call.arguments + (msg,))
return call


class DiagHaloExchangeBuilder(BasicHaloExchangeBuilder):

"""
Expand Down Expand Up @@ -741,141 +865,6 @@ def _call_remainder(self, remainder):
return call


class Basic2HaloExchangeBuilder(BasicHaloExchangeBuilder):

"""
A BasicHaloExchangeBuilder making use of pre-allocated buffers for
message size.
Generates:
haloupdate()
compute()
"""

def _make_msg(self, f, hse, key):
# Pass the fixed mapper e.g. {t: otime}
fixed = {d: Symbol(name="o%s" % d.root) for d in hse.loc_indices}

return MPIMsgBasic2('msg%d' % key, f, hse.halos, fixed)

def _make_sendrecv(self, f, hse, key, msg=None):
cast = cast_mapper[(f.c0.dtype, '*')]
comm = f.grid.distributor._obj_comm

bufg = FieldFromPointer(msg._C_field_bufg, msg)
bufs = FieldFromPointer(msg._C_field_bufs, msg)

ofsg = [Symbol(name='og%s' % d.root) for d in f.dimensions]
ofss = [Symbol(name='os%s' % d.root) for d in f.dimensions]

fromrank = Symbol(name='fromrank')
torank = Symbol(name='torank')

sizes = [FieldFromPointer('%s[%d]' % (msg._C_field_sizes, i), msg)
for i in range(len(f._dist_dimensions))]

arguments = [cast(bufg)] + sizes + list(f.handles) + ofsg
gather = Gather('gather%s' % key, arguments)
# The `gather` is unnecessary if sending to MPI.PROC_NULL
gather = Conditional(CondNe(torank, Macro('MPI_PROC_NULL')), gather)

arguments = [cast(bufs)] + sizes + list(f.handles) + ofss
scatter = Scatter('scatter%s' % key, arguments)
# The `scatter` must be guarded as we must not alter the halo values along
# the domain boundary, where the sender is actually MPI.PROC_NULL
scatter = Conditional(CondNe(fromrank, Macro('MPI_PROC_NULL')), scatter)

count = reduce(mul, sizes, 1)*dtype_len(f.dtype)
rrecv = Byref(FieldFromPointer(msg._C_field_rrecv, msg))
rsend = Byref(FieldFromPointer(msg._C_field_rsend, msg))
recv = IrecvCall([bufs, count, Macro(dtype_to_mpitype(f.dtype)),
fromrank, Integer(13), comm, rrecv])
send = IsendCall([bufg, count, Macro(dtype_to_mpitype(f.dtype)),
torank, Integer(13), comm, rsend])

waitrecv = Call('MPI_Wait', [rrecv, Macro('MPI_STATUS_IGNORE')])
waitsend = Call('MPI_Wait', [rsend, Macro('MPI_STATUS_IGNORE')])

iet = List(body=[recv, gather, send, waitsend, waitrecv, scatter])

parameters = (list(f.handles) + ofsg + ofss + [fromrank, torank, comm, msg])

return SendRecv('sendrecv%s' % key, iet, parameters, bufg, bufs)

def _call_sendrecv(self, name, *args, msg=None, haloid=None):
# Drop `sizes` as this HaloExchangeBuilder conveys them through `msg`
f, _, ofsg, ofss, fromrank, torank, comm = args
msg = Byref(IndexedPointer(msg, haloid))
return Call(name, list(f.handles) + ofsg + ofss + [fromrank, torank, comm, msg])

def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs):
distributor = f.grid.distributor
nb = distributor._obj_neighborhood
comm = distributor._obj_comm

fixed = {d: Symbol(name="o%s" % d.root) for d in hse.loc_indices}

# Build a mapper `(dim, side, region) -> (size, ofs)` for `f`. `size` and
# `ofs` are symbolic objects. This mapper tells what data values should be
# sent (OWNED) or received (HALO) given dimension and side
mapper = {}
for d0, side, region in product(f.dimensions, (LEFT, RIGHT), (OWNED, HALO)):
if d0 in fixed:
continue
sizes = []
ofs = []
for d1 in f.dimensions:
if d1 in fixed:
ofs.append(fixed[d1])
else:
meta = f._C_get_field(region if d0 is d1 else NOPAD, d1, side)
ofs.append(meta.offset)
sizes.append(meta.size)
mapper[(d0, side, region)] = (sizes, ofs)

body = []
for d in f.dimensions:
if d in fixed:
continue

name = ''.join('r' if i is d else 'c' for i in distributor.dimensions)
rpeer = FieldFromPointer(name, nb)
name = ''.join('l' if i is d else 'c' for i in distributor.dimensions)
lpeer = FieldFromPointer(name, nb)

if (d, LEFT) in hse.halos:
# Sending to left, receiving from right
lsizes, lofs = mapper[(d, LEFT, OWNED)]
rsizes, rofs = mapper[(d, RIGHT, HALO)]
args = [f, lsizes, lofs, rofs, rpeer, lpeer, comm]
kwargs['haloid'] = len(body)
body.append(self._call_sendrecv(sendrecv.name, *args, **kwargs))

if (d, RIGHT) in hse.halos:
# Sending to right, receiving from left
rsizes, rofs = mapper[(d, RIGHT, OWNED)]
lsizes, lofs = mapper[(d, LEFT, HALO)]
args = [f, rsizes, rofs, lofs, lpeer, rpeer, comm]
kwargs['haloid'] = len(body)
body.append(self._call_sendrecv(sendrecv.name, *args, **kwargs))

iet = List(body=body)

parameters = list(f.handles) + [comm, nb] + list(fixed.values())

node = HaloUpdate('haloupdate%s' % key, iet, parameters)

node = node._rebuild(parameters=node.parameters + (kwargs['msg'],))

return node

def _call_haloupdate(self, name, f, hse, msg):
call = super()._call_haloupdate(name, f, hse)
call = call._rebuild(arguments=call.arguments + (msg,))
return call


class Overlap2HaloExchangeBuilder(OverlapHaloExchangeBuilder):

"""
Expand Down

0 comments on commit befd078

Please sign in to comment.