Skip to content
This repository has been archived by the owner on Nov 27, 2024. It is now read-only.

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
JDBetteridge committed Nov 14, 2023
1 parent 3beaeb0 commit 2881c05
Showing 1 changed file with 59 additions and 35 deletions.
94 changes: 59 additions & 35 deletions pyop2/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,26 +177,40 @@ def delcomm_outer(comm, keyval, icomm):
"""
# This will raise errors at cleanup time as some objects are already
# deleted, so we just skip
if not PYOP2_FINALIZED:
if keyval not in (innercomm_keyval, compilationcomm_keyval):
raise PyOP2CommError("Unexpected keyval")
ocomm = icomm.Get_attr(outercomm_keyval)
if ocomm is None:
raise PyOP2CommError("Inner comm does not have expected reference to outer comm")

if ocomm != comm:
raise PyOP2CommError("Inner comm has reference to non-matching outer comm")
icomm.Delete_attr(outercomm_keyval)

# Once we have removed the reference to the inner/compilation comm we can free it
cidx = icomm.Get_attr(cidx_keyval)
cidx = cidx[0]
del _DUPED_COMM_DICT[cidx]
gc.collect()
refcount = icomm.Get_attr(refcount_keyval)
if refcount[0] > 1:
raise PyOP2CommError("References to comm still held, this will cause deadlock")
icomm.Free()
# ~ if not PYOP2_FINALIZED:
if keyval not in (innercomm_keyval, compilationcomm_keyval):
raise PyOP2CommError("Unexpected keyval")

if keyval == innercomm_keyval:
debug(f'Deleting innercomm keyval on {comm.name}')
if keyval == compilationcomm_keyval:
debug(f'Deleting compilationcomm keyval on {comm.name}')

# Let's be charitable and assume the comm has already been destroyed
# ~ if icomm == MPI.COMM_NULL:
ocomm = icomm.Get_attr(outercomm_keyval)
if ocomm is None:
raise PyOP2CommError("Inner comm does not have expected reference to outer comm")

if ocomm != comm:
raise PyOP2CommError("Inner comm has reference to non-matching outer comm")
icomm.Delete_attr(outercomm_keyval)

comp_comm = icomm.Get_attr(compilationcomm_keyval)
if comp_comm is not None:
debug('Removing compilation comm on inner comm')
decref(comp_comm)
icomm.Delete_attr(compilationcomm_keyval)

# Once we have removed the reference to the inner/compilation comm we can free it
cidx = icomm.Get_attr(cidx_keyval)
cidx = cidx[0]
del _DUPED_COMM_DICT[cidx]
gc.collect()
refcount = icomm.Get_attr(refcount_keyval)
if refcount[0] > 1 and not PYOP2_FINALIZED:
raise PyOP2CommError("References to comm still held, this will cause deadlock")
icomm.Free()


# Reference count, creation index, inner/outer/compilation communicator
Expand All @@ -219,10 +233,10 @@ def is_pyop2_comm(comm):
if isinstance(comm, PETSc.Comm):
ispyop2comm = False
elif comm == MPI.COMM_NULL:
if not PYOP2_FINALIZED:
raise PyOP2CommError("Communicator passed to is_pyop2_comm() is COMM_NULL")
else:
ispyop2comm = True
# ~ if not PYOP2_FINALIZED:
raise PyOP2CommError("Communicator passed to is_pyop2_comm() is COMM_NULL")
# ~ else:
# ~ ispyop2comm = True
elif isinstance(comm, MPI.Comm):
ispyop2comm = bool(comm.Get_attr(refcount_keyval))
else:
Expand Down Expand Up @@ -307,23 +321,25 @@ def incref(comm):
assert is_pyop2_comm(comm)
refcount = comm.Get_attr(refcount_keyval)
refcount[0] += 1
print(f"INCREF {comm.name} TO {refcount[0]}")


def decref(comm):
""" Decrement communicator reference count
"""
if not PYOP2_FINALIZED:
assert is_pyop2_comm(comm)
refcount = comm.Get_attr(refcount_keyval)
refcount[0] -= 1
if refcount[0] == 1:
# Freeing the comm is handled by the destruction of the user comm
pass
elif refcount[0] < 1:
raise PyOP2CommError("Reference count is less than 1, decref called too many times")
# ~ if not PYOP2_FINALIZED:
assert is_pyop2_comm(comm)
refcount = comm.Get_attr(refcount_keyval)
refcount[0] -= 1
print(f"DECREF {comm.name} TO {refcount[0]}")
if refcount[0] == 1:
# Freeing the comm is handled by the destruction of the user comm
pass
elif refcount[0] < 1:
raise PyOP2CommError("Reference count is less than 1, decref called too many times")

elif comm != MPI.COMM_NULL:
comm.Free()
# ~ elif comm != MPI.COMM_NULL:
# ~ comm.Free()


def dup_comm(comm_in):
Expand Down Expand Up @@ -479,12 +495,20 @@ def _free_comms():
debug = lambda string: print(string)
debug("PyOP2 Finalizing")
# Collect garbage as it may hold on to communicator references

debug("Calling gc.collect()")
gc.collect()
debug("STATE0")
debug(pyop2_comm_status())

debug("Freeing PYOP2_COMM_WORLD")
COMM_WORLD.Free()
debug("STATE1")
debug(pyop2_comm_status())

debug("Freeing PYOP2_COMM_SELF")
COMM_SELF.Free()
debug("STATE2")
debug(pyop2_comm_status())
debug(f"Freeing comms in list (length {len(_DUPED_COMM_DICT)})")
for key in sorted(_DUPED_COMM_DICT.keys()):
Expand Down

0 comments on commit 2881c05

Please sign in to comment.