From ad0c4303d89f124479fe9124ac12c9b30bb823f4 Mon Sep 17 00:00:00 2001 From: Jack Betteridge <43041811+JDBetteridge@users.noreply.github.com> Date: Wed, 17 Jan 2024 16:38:11 +0000 Subject: [PATCH] Comm reference fixes + Remove __del__ method and add weakref.finalizer (#712) --- .github/workflows/ci.yml | 1 + pyop2/caching.py | 7 -- pyop2/compilation.py | 10 +-- pyop2/mpi.py | 139 +++++++++++++++++++++++++------------- pyop2/parloop.py | 6 +- pyop2/types/dat.py | 8 +-- pyop2/types/dataset.py | 12 +--- pyop2/types/glob.py | 10 +-- pyop2/types/map.py | 12 ++-- pyop2/types/mat.py | 46 +++++-------- pyop2/types/set.py | 25 +++---- requirements-ext.txt | 1 + setup.py | 1 + test/unit/test_caching.py | 4 +- 14 files changed, 136 insertions(+), 146 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8a1c600af..788186ac9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -24,6 +24,7 @@ jobs: PETSC_ARCH: default PETSC_CONFIGURE_OPTIONS: --with-debugging=1 --with-shared-libraries=1 --with-c2html=0 --with-fortran-bindings=0 RDMAV_FORK_SAFE: 1 + PYOP2_CI_TESTS: 1 timeout-minutes: 60 steps: diff --git a/pyop2/caching.py b/pyop2/caching.py index 24a3f5513..0f036212f 100644 --- a/pyop2/caching.py +++ b/pyop2/caching.py @@ -83,13 +83,6 @@ class ObjectCached(object): details). The object on which the cache is stored should contain a dict in its ``_cache`` attribute. - .. warning :: - - This kind of cache sets up a circular reference. If either of - the objects implements ``__del__``, the Python garbage - collector will not be able to collect this cycle, and hence - the cache will never be evicted. - .. 
warning:: The derived class' :meth:`__init__` is still called if the diff --git a/pyop2/compilation.py b/pyop2/compilation.py index ffd01f3b5..794024a8d 100644 --- a/pyop2/compilation.py +++ b/pyop2/compilation.py @@ -188,14 +188,8 @@ def __init__(self, extra_compiler_flags=(), extra_linker_flags=(), cpp=False, co self._debug = configuration["debug"] # Compilation communicators are reference counted on the PyOP2 comm - self.pcomm = mpi.internal_comm(comm) - self.comm = mpi.compilation_comm(self.pcomm) - - def __del__(self): - if hasattr(self, "comm"): - mpi.decref(self.comm) - if hasattr(self, "pcomm"): - mpi.decref(self.pcomm) + self.pcomm = mpi.internal_comm(comm, self) + self.comm = mpi.compilation_comm(self.pcomm, self) def __repr__(self): return f"<{self._name} compiler, version {self.version or 'unknown'}>" diff --git a/pyop2/mpi.py b/pyop2/mpi.py index 4b65f2e95..a84fa2b51 100644 --- a/pyop2/mpi.py +++ b/pyop2/mpi.py @@ -42,6 +42,7 @@ import glob import os import tempfile +import weakref from pyop2.configuration import configuration from pyop2.exceptions import CompilationError @@ -74,6 +75,8 @@ _DUPED_COMM_DICT = {} # Flag to indicate whether we are in cleanup (at exit) PYOP2_FINALIZED = False +# Flag for outputting information at the end of testing (do not abuse!) +_running_on_ci = bool(os.environ.get('PYOP2_CI_TESTS')) class PyOP2CommError(ValueError): @@ -175,28 +178,46 @@ def delcomm_outer(comm, keyval, icomm): :arg icomm: The inner communicator, should have a reference to ``comm``. 
""" - # This will raise errors at cleanup time as some objects are already - # deleted, so we just skip - if not PYOP2_FINALIZED: - if keyval not in (innercomm_keyval, compilationcomm_keyval): - raise PyOP2CommError("Unexpected keyval") - ocomm = icomm.Get_attr(outercomm_keyval) - if ocomm is None: - raise PyOP2CommError("Inner comm does not have expected reference to outer comm") - - if ocomm != comm: - raise PyOP2CommError("Inner comm has reference to non-matching outer comm") - icomm.Delete_attr(outercomm_keyval) - - # Once we have removed the reference to the inner/compilation comm we can free it - cidx = icomm.Get_attr(cidx_keyval) - cidx = cidx[0] - del _DUPED_COMM_DICT[cidx] - gc.collect() - refcount = icomm.Get_attr(refcount_keyval) - if refcount[0] > 1: - raise PyOP2CommError("References to comm still held, this will cause deadlock") - icomm.Free() + # Use debug printer that is safe to use at exit time + debug = finalize_safe_debug() + if keyval not in (innercomm_keyval, compilationcomm_keyval): + raise PyOP2CommError("Unexpected keyval") + + if keyval == innercomm_keyval: + debug(f'Deleting innercomm keyval on {comm.name}') + if keyval == compilationcomm_keyval: + debug(f'Deleting compilationcomm keyval on {comm.name}') + + ocomm = icomm.Get_attr(outercomm_keyval) + if ocomm is None: + raise PyOP2CommError("Inner comm does not have expected reference to outer comm") + + if ocomm != comm: + raise PyOP2CommError("Inner comm has reference to non-matching outer comm") + icomm.Delete_attr(outercomm_keyval) + + # An inner comm may or may not hold a reference to a compilation comm + comp_comm = icomm.Get_attr(compilationcomm_keyval) + if comp_comm is not None: + debug('Removing compilation comm on inner comm') + decref(comp_comm) + icomm.Delete_attr(compilationcomm_keyval) + + # Once we have removed the reference to the inner/compilation comm we can free it + cidx = icomm.Get_attr(cidx_keyval) + cidx = cidx[0] + del _DUPED_COMM_DICT[cidx] + gc.collect() + 
refcount = icomm.Get_attr(refcount_keyval) + if refcount[0] > 1: + # In the case where `comm` is a custom user communicator there may be references + # to the inner comm still held and this is not an issue, but there is not an + # easy way to distinguish this case, so we just log the event. + debug( + f"There are still {refcount[0]} references to {comm.name}, " + "this will cause deadlock if the communicator has been incorrectly freed" + ) + icomm.Free() # Reference count, creation index, inner/outer/compilation communicator @@ -215,14 +236,10 @@ def is_pyop2_comm(comm): :arg comm: Communicator to query """ - global PYOP2_FINALIZED if isinstance(comm, PETSc.Comm): ispyop2comm = False elif comm == MPI.COMM_NULL: - if not PYOP2_FINALIZED: - raise PyOP2CommError("Communicator passed to is_pyop2_comm() is COMM_NULL") - else: - ispyop2comm = True + raise PyOP2CommError("Communicator passed to is_pyop2_comm() is COMM_NULL") elif isinstance(comm, MPI.Comm): ispyop2comm = bool(comm.Get_attr(refcount_keyval)) else: @@ -231,7 +248,8 @@ def is_pyop2_comm(comm): def pyop2_comm_status(): - """ Prints the reference counts for all comms PyOP2 has duplicated + """ Return string containing a table of the reference counts for all + communicators PyOP2 has duplicated. """ status_string = 'PYOP2 Communicator reference counts:\n' status_string += '| Communicator name | Count |\n' @@ -255,10 +273,7 @@ class temp_internal_comm: """ def __init__(self, comm): self.user_comm = comm - self.internal_comm = internal_comm(self.user_comm) - - def __del__(self): - decref(self.internal_comm) + self.internal_comm = internal_comm(self.user_comm, self) def __enter__(self): """ Returns an internal comm that will be safely decref'd @@ -272,10 +287,12 @@ def __exit__(self, exc_type, exc_value, traceback): pass -def internal_comm(comm): +def internal_comm(comm, obj): """ Creates an internal comm from the user comm. 
If comm is None, create an internal communicator from COMM_WORLD :arg comm: A communicator or None + :arg obj: The object which the comm is an attribute of + (usually `self`) :returns pyop2_comm: A PyOP2 internal communicator """ @@ -298,6 +315,7 @@ def internal_comm(comm): pyop2_comm = comm else: pyop2_comm = dup_comm(comm) + weakref.finalize(obj, decref, pyop2_comm) return pyop2_comm @@ -312,19 +330,18 @@ def incref(comm): def decref(comm): """ Decrement communicator reference count """ - if not PYOP2_FINALIZED: + if comm == MPI.COMM_NULL: + # This case occurs if the outer communicator has already been freed by + # the user + debug("Cannot decref an already freed communicator") + else: assert is_pyop2_comm(comm) refcount = comm.Get_attr(refcount_keyval) refcount[0] -= 1 - if refcount[0] == 1: - # Freeing the comm is handled by the destruction of the user comm - pass - elif refcount[0] < 1: + # Freeing the internal comm is handled by the destruction of the user comm + if refcount[0] < 1: raise PyOP2CommError("Reference count is less than 1, decref called too many times") - elif comm != MPI.COMM_NULL: - comm.Free() - def dup_comm(comm_in): """Given a communicator return a communicator for internal use. @@ -440,10 +457,13 @@ def set_compilation_comm(comm, comp_comm): @collective -def compilation_comm(comm): +def compilation_comm(comm, obj): """Get a communicator for compilation. :arg comm: The input communicator, must be a PyOP2 comm. + :arg obj: The object which the comm is an attribute of + (usually `self`) + :returns: A communicator used for compilation (may be smaller) """ if not is_pyop2_comm(comm): @@ -465,29 +485,54 @@ def compilation_comm(comm): else: comp_comm = comm incref(comp_comm) + weakref.finalize(obj, decref, comp_comm) return comp_comm +def finalize_safe_debug(): + ''' Return function for debug output. + + When Python is finalizing the logging module may be finalized before we have + finished writing debug information.
In this case we fall back to using the + Python `print` function to output debugging information. + + Furthermore, we always want to see this finalization information when + running the CI tests. + ''' + global debug + if PYOP2_FINALIZED: + if logger.level > DEBUG and not _running_on_ci: + debug = lambda string: None + else: + debug = lambda string: print(string) + return debug + + @atexit.register def _free_comms(): """Free all outstanding communicators.""" global PYOP2_FINALIZED PYOP2_FINALIZED = True - if logger.level > DEBUG: - debug = lambda string: None - else: - debug = lambda string: print(string) + debug = finalize_safe_debug() debug("PyOP2 Finalizing") # Collect garbage as it may hold on to communicator references + debug("Calling gc.collect()") gc.collect() + debug("STATE0") + debug(pyop2_comm_status()) + debug("Freeing PYOP2_COMM_WORLD") COMM_WORLD.Free() + debug("STATE1") + debug(pyop2_comm_status()) + debug("Freeing PYOP2_COMM_SELF") COMM_SELF.Free() + debug("STATE2") debug(pyop2_comm_status()) debug(f"Freeing comms in list (length {len(_DUPED_COMM_DICT)})") - for key in sorted(_DUPED_COMM_DICT.keys()): + for key in sorted(_DUPED_COMM_DICT.keys(), reverse=True): comm = _DUPED_COMM_DICT[key] if comm != MPI.COMM_NULL: refcount = comm.Get_attr(refcount_keyval) diff --git a/pyop2/parloop.py b/pyop2/parloop.py index 776b58c8d..cf96ba5b4 100644 --- a/pyop2/parloop.py +++ b/pyop2/parloop.py @@ -151,13 +151,9 @@ def __init__(self, global_knl, iterset, arguments): self.global_kernel = global_knl self.iterset = iterset - self.comm = mpi.internal_comm(iterset.comm) + self.comm = mpi.internal_comm(iterset.comm, self) self.arguments, self.reduced_globals = self.prepare_reduced_globals(arguments, global_knl) - def __del__(self): - if hasattr(self, "comm"): - mpi.decref(self.comm) - @property def local_kernel(self): return self.global_kernel.local_kernel diff --git a/pyop2/types/dat.py b/pyop2/types/dat.py index 826921e67..5ee339bcc 100644 --- a/pyop2/types/dat.py 
+++ b/pyop2/types/dat.py @@ -82,17 +82,13 @@ def __init__(self, dataset, data=None, dtype=None, name=None): EmptyDataMixin.__init__(self, data, dtype, self._shape) self._dataset = dataset - self.comm = mpi.internal_comm(dataset.comm) + self.comm = mpi.internal_comm(dataset.comm, self) self.halo_valid = True self._name = name or "dat_#x%x" % id(self) self._halo_frozen = False self._frozen_access_mode = None - def __del__(self): - if hasattr(self, "comm"): - mpi.decref(self.comm) - @utils.cached_property def _kernel_args_(self): return (self._data.ctypes.data, ) @@ -823,7 +819,7 @@ def what(x): if not all(d.dtype == self._dats[0].dtype for d in self._dats): raise ex.DataValueError('MixedDat with different dtypes is not supported') # TODO: Think about different communicators on dats (c.f. MixedSet) - self.comm = mpi.internal_comm(self._dats[0].comm) + self.comm = mpi.internal_comm(self._dats[0].comm, self) @property def dat_version(self): diff --git a/pyop2/types/dataset.py b/pyop2/types/dataset.py index e554bbcef..8d3ba0472 100644 --- a/pyop2/types/dataset.py +++ b/pyop2/types/dataset.py @@ -29,19 +29,13 @@ def __init__(self, iter_set, dim=1, name=None): return if isinstance(iter_set, Subset): raise NotImplementedError("Deriving a DataSet from a Subset is unsupported") - self.comm = mpi.internal_comm(iter_set.comm) + self.comm = mpi.internal_comm(iter_set.comm, self) self._set = iter_set self._dim = utils.as_tuple(dim, numbers.Integral) self._cdim = np.prod(self._dim).item() self._name = name or "dset_#x%x" % id(self) self._initialized = True - def __del__(self): - # Cannot use hasattr here, since we define `__getattr__` - # This causes infinite recursion when looked up! 
- if "comm" in self.__dict__: - mpi.decref(self.comm) - @classmethod def _process_args(cls, *args, **kwargs): return (args[0], ) + args, kwargs @@ -211,7 +205,7 @@ def __init__(self, global_): if self._initialized: return self._global = global_ - self.comm = mpi.internal_comm(global_.comm) + self.comm = mpi.internal_comm(global_.comm, self) self._globalset = GlobalSet(comm=self.comm) self._name = "gdset_#x%x" % id(self) self._initialized = True @@ -360,7 +354,7 @@ def __init__(self, arg, dims=None): comm = self._process_args(arg, dims)[0][0].comm except AttributeError: comm = None - self.comm = mpi.internal_comm(comm) + self.comm = mpi.internal_comm(comm, self) self._initialized = True @classmethod diff --git a/pyop2/types/glob.py b/pyop2/types/glob.py index daacc6a64..d8ed99134 100644 --- a/pyop2/types/glob.py +++ b/pyop2/types/glob.py @@ -26,10 +26,6 @@ def __init__(self, dim, data=None, dtype=None, name=None): self._buf = np.empty(self.shape, dtype=self.dtype) self._name = name or "%s_#x%x" % (self.__class__.__name__.lower(), id(self)) - def __del__(self): - if hasattr(self, "comm"): - mpi.decref(self.comm) - @utils.cached_property def _kernel_args_(self): return (self._data.ctypes.data, ) @@ -247,16 +243,12 @@ def __init__(self, dim, data=None, dtype=None, name=None, comm=None): super().__init__(dim, data, dtype, name) if comm is None: warnings.warn("PyOP2.Global has no comm, this is likely to break in parallel!") - self.comm = mpi.internal_comm(comm) + self.comm = mpi.internal_comm(comm, self) # Object versioning setup petsc_counter = (comm and self.dtype == PETSc.ScalarType) VecAccessMixin.__init__(self, petsc_counter=petsc_counter) - def __del__(self): - if hasattr(self, "comm"): - mpi.decref(self.comm) - def __str__(self): return "OP2 Global Argument: %s with dim %s and value %s" \ % (self._name, self._dim, self._data) diff --git a/pyop2/types/map.py b/pyop2/types/map.py index 91224d52a..9d9ca48ae 100644 --- a/pyop2/types/map.py +++ b/pyop2/types/map.py @@ 
-36,7 +36,7 @@ class Map: def __init__(self, iterset, toset, arity, values=None, name=None, offset=None, offset_quotient=None): self._iterset = iterset self._toset = toset - self.comm = mpi.internal_comm(toset.comm) + self.comm = mpi.internal_comm(toset.comm, self) self._arity = arity self._values = utils.verify_reshape(values, dtypes.IntType, (iterset.total_size, arity), allow_none=True) @@ -53,10 +53,6 @@ def __init__(self, iterset, toset, arity, values=None, name=None, offset=None, o # A cache for objects built on top of this map self._cache = {} - def __del__(self): - if hasattr(self, "comm"): - mpi.decref(self.comm) - @utils.cached_property def _kernel_args_(self): return (self._values.ctypes.data, ) @@ -200,7 +196,7 @@ def __init__(self, map_, permutation): if isinstance(map_, ComposedMap): raise NotImplementedError("PermutedMap of ComposedMap not implemented: simply permute before composing") self.map_ = map_ - self.comm = mpi.internal_comm(map_.comm) + self.comm = mpi.internal_comm(map_.comm, self) self.permutation = np.asarray(permutation, dtype=Map.dtype) assert (np.unique(permutation) == np.arange(map_.arity, dtype=Map.dtype)).all() @@ -251,7 +247,7 @@ def __init__(self, *maps_, name=None): raise ex.MapTypeError("frommap.arity must be 1") self._iterset = maps_[-1].iterset self._toset = maps_[0].toset - self.comm = mpi.internal_comm(self._toset.comm) + self.comm = mpi.internal_comm(self._toset.comm, self) self._arity = maps_[0].arity # Don't call super().__init__() to avoid calling verify_reshape() self._values = None @@ -315,7 +311,7 @@ def __init__(self, maps): raise ex.MapTypeError("All maps needs to share a communicator") if len(comms) == 0: raise ex.MapTypeError("Don't know how to make communicator") - self.comm = mpi.internal_comm(comms[0]) + self.comm = mpi.internal_comm(comms[0], self) self._initialized = True @classmethod diff --git a/pyop2/types/mat.py b/pyop2/types/mat.py index a5ad65f71..b96594a1e 100644 --- a/pyop2/types/mat.py +++ 
b/pyop2/types/mat.py @@ -66,11 +66,17 @@ def __init__(self, dsets, maps, *, iteration_regions=None, name=None, nest=None, self._dims = (((1, 1),),) self._d_nnz = None self._o_nnz = None - self.lcomm = mpi.internal_comm(dsets[0].comm if isinstance(dsets[0], GlobalDataSet) else self._rmaps[0].comm) - self.rcomm = mpi.internal_comm(dsets[1].comm if isinstance(dsets[1], GlobalDataSet) else self._cmaps[0].comm) + self.lcomm = mpi.internal_comm( + dsets[0].comm if isinstance(dsets[0], GlobalDataSet) else self._rmaps[0].comm, + self + ) + self.rcomm = mpi.internal_comm( + dsets[1].comm if isinstance(dsets[1], GlobalDataSet) else self._cmaps[0].comm, + self + ) else: - self.lcomm = mpi.internal_comm(self._rmaps[0].comm) - self.rcomm = mpi.internal_comm(self._cmaps[0].comm) + self.lcomm = mpi.internal_comm(self._rmaps[0].comm, self) + self.rcomm = mpi.internal_comm(self._cmaps[0].comm, self) rset, cset = self.dsets @@ -88,7 +94,7 @@ def __init__(self, dsets, maps, *, iteration_regions=None, name=None, nest=None, if self.lcomm != self.rcomm: raise ValueError("Haven't thought hard enough about different left and right communicators") - self.comm = mpi.internal_comm(self.lcomm) + self.comm = mpi.internal_comm(self.lcomm, self) self._name = name or "sparsity_#x%x" % id(self) self.iteration_regions = iteration_regions # If the Sparsity is defined on MixedDataSets, we need to build each @@ -124,14 +130,6 @@ def __init__(self, dsets, maps, *, iteration_regions=None, name=None, nest=None, self._blocks = [[self]] self._initialized = True - def __del__(self): - if hasattr(self, "comm"): - mpi.decref(self.comm) - if hasattr(self, "lcomm"): - mpi.decref(self.lcomm) - if hasattr(self, "rcomm"): - mpi.decref(self.rcomm) - _cache = {} @classmethod @@ -366,10 +364,10 @@ def __init__(self, parent, i, j): self._dims = tuple([tuple([parent.dims[i][j]])]) self._blocks = [[self]] self.iteration_regions = parent.iteration_regions - self.lcomm = mpi.internal_comm(self.dsets[0].comm) - self.rcomm 
= mpi.internal_comm(self.dsets[1].comm) + self.lcomm = mpi.internal_comm(self.dsets[0].comm, self) + self.rcomm = mpi.internal_comm(self.dsets[1].comm, self) # TODO: think about lcomm != rcomm - self.comm = mpi.internal_comm(self.lcomm) + self.comm = mpi.internal_comm(self.lcomm, self) self._initialized = True @classmethod @@ -428,22 +426,14 @@ class AbstractMat(DataCarrier, abc.ABC): ('name', str, ex.NameTypeError)) def __init__(self, sparsity, dtype=None, name=None): self._sparsity = sparsity - self.lcomm = mpi.internal_comm(sparsity.lcomm) - self.rcomm = mpi.internal_comm(sparsity.rcomm) - self.comm = mpi.internal_comm(sparsity.comm) + self.lcomm = mpi.internal_comm(sparsity.lcomm, self) + self.rcomm = mpi.internal_comm(sparsity.rcomm, self) + self.comm = mpi.internal_comm(sparsity.comm, self) dtype = dtype or dtypes.ScalarType self._datatype = np.dtype(dtype) self._name = name or "mat_#x%x" % id(self) self.assembly_state = Mat.ASSEMBLED - def __del__(self): - if hasattr(self, "comm"): - mpi.decref(self.comm) - if hasattr(self, "lcomm"): - mpi.decref(self.lcomm) - if hasattr(self, "rcomm"): - mpi.decref(self.rcomm) - @utils.validate_in(('access', _modes, ex.ModeValueError)) def __call__(self, access, path, lgmaps=None, unroll_map=False): from pyop2.parloop import MatLegacyArg, MixedMatLegacyArg @@ -943,7 +933,7 @@ def __init__(self, parent, i, j): colis = cset.local_ises[j] self.handle = parent.handle.getLocalSubMatrix(isrow=rowis, iscol=colis) - self.comm = mpi.internal_comm(parent.comm) + self.comm = mpi.internal_comm(parent.comm, self) self.local_to_global_maps = self.handle.getLGMap() @property diff --git a/pyop2/types/set.py b/pyop2/types/set.py index f6b09e9d2..25abdf93c 100644 --- a/pyop2/types/set.py +++ b/pyop2/types/set.py @@ -1,8 +1,8 @@ import ctypes -import functools import numbers import numpy as np +import pytools from pyop2 import ( caching, @@ -65,7 +65,7 @@ def _wrapper_cache_key_(self): @utils.validate_type(('size', (numbers.Integral, tuple, 
list, np.ndarray), ex.SizeTypeError), ('name', str, ex.NameTypeError)) def __init__(self, size, name=None, halo=None, comm=None): - self.comm = mpi.internal_comm(comm) + self.comm = mpi.internal_comm(comm, self) if isinstance(size, numbers.Integral): size = [size] * 3 size = utils.as_tuple(size, numbers.Integral, 3) @@ -78,12 +78,6 @@ def __init__(self, size, name=None, halo=None, comm=None): # A cache of objects built on top of this set self._cache = {} - def __del__(self): - # Cannot use hasattr here, since child classes define `__getattr__` - # This causes infinite recursion when looked up! - if "comm" in self.__dict__: - mpi.decref(self.comm) - @utils.cached_property def core_size(self): """Core set size. Owned elements not touching halo elements.""" @@ -233,7 +227,7 @@ class GlobalSet(Set): _argtypes_ = () def __init__(self, comm=None): - self.comm = mpi.internal_comm(comm) + self.comm = mpi.internal_comm(comm, self) self._cache = {} @utils.cached_property @@ -318,7 +312,7 @@ class ExtrudedSet(Set): @utils.validate_type(('parent', Set, TypeError)) def __init__(self, parent, layers, extruded_periodic=False): self._parent = parent - self.comm = mpi.internal_comm(parent.comm) + self.comm = mpi.internal_comm(parent.comm, self) try: layers = utils.verify_reshape(layers, dtypes.IntType, (parent.total_size, 2)) self.constant_layers = False @@ -399,7 +393,7 @@ class Subset(ExtrudedSet): @utils.validate_type(('superset', Set, TypeError), ('indices', (list, tuple, np.ndarray), TypeError)) def __init__(self, superset, indices): - self.comm = mpi.internal_comm(superset.comm) + self.comm = mpi.internal_comm(superset.comm, self) # sort and remove duplicates indices = np.unique(indices) @@ -543,13 +537,12 @@ def __init__(self, sets): assert all(s is None or isinstance(s, GlobalSet) or ((s.layers == self._sets[0].layers).all() if s.layers is not None else True) for s in sets), \ "All components of a MixedSet must have the same number of layers." 
# TODO: do all sets need the same communicator? - self.comm = mpi.internal_comm(functools.reduce(lambda a, b: a or b, map(lambda s: s if s is None else s.comm, sets))) + self.comm = mpi.internal_comm( + pytools.single_valued(s.comm for s in sets if s is not None), + self + ) self._initialized = True - def __del__(self): - if self._initialized and hasattr(self, "comm"): - mpi.decref(self.comm) - @utils.cached_property def _kernel_args_(self): raise NotImplementedError diff --git a/requirements-ext.txt b/requirements-ext.txt index 0f19e0d06..2ccb04374 100644 --- a/requirements-ext.txt +++ b/requirements-ext.txt @@ -8,3 +8,4 @@ decorator<=4.4.2 dataclasses cachetools packaging +pytools diff --git a/setup.py b/setup.py index ad9f7815b..06c03e152 100644 --- a/setup.py +++ b/setup.py @@ -89,6 +89,7 @@ def get_petsc_dir(): 'decorator', 'mpi4py', 'numpy>=1.6', + 'pytools', ] version = sys.version_info[:2] diff --git a/test/unit/test_caching.py b/test/unit/test_caching.py index 1c43ce52f..1b95abebc 100644 --- a/test/unit/test_caching.py +++ b/test/unit/test_caching.py @@ -536,8 +536,7 @@ def myfunc(arg): def collective_key(self, *args): """Return a cache key suitable for use when collective over a communicator.""" - # Explicitly `mpi.decref(self.comm)` in any test that uses this comm - self.comm = mpi.internal_comm(mpi.COMM_SELF) + self.comm = mpi.internal_comm(mpi.COMM_SELF, self) return self.comm, cachetools.keys.hashkey(*args) @pytest.fixture @@ -575,7 +574,6 @@ def test_decorator_collective_has_different_in_memory_key(self, cache, cachedir) assert obj1 == obj2 and obj1 is not obj2 assert len(cache) == 2 assert len(os.listdir(cachedir.name)) == 1 - mpi.decref(self.comm) def test_decorator_disk_cache_reuses_results(self, cache, cachedir): decorated_func = disk_cached(cache, cachedir.name)(self.myfunc)