diff --git a/pyop2/caching.py b/pyop2/caching.py index 24a3f5513..0f036212f 100644 --- a/pyop2/caching.py +++ b/pyop2/caching.py @@ -83,13 +83,6 @@ class ObjectCached(object): details). The object on which the cache is stored should contain a dict in its ``_cache`` attribute. - .. warning :: - - This kind of cache sets up a circular reference. If either of - the objects implements ``__del__``, the Python garbage - collector will not be able to collect this cycle, and hence - the cache will never be evicted. - .. warning:: The derived class' :meth:`__init__` is still called if the diff --git a/pyop2/compilation.py b/pyop2/compilation.py index ffd01f3b5..794024a8d 100644 --- a/pyop2/compilation.py +++ b/pyop2/compilation.py @@ -188,14 +188,8 @@ def __init__(self, extra_compiler_flags=(), extra_linker_flags=(), cpp=False, co self._debug = configuration["debug"] # Compilation communicators are reference counted on the PyOP2 comm - self.pcomm = mpi.internal_comm(comm) - self.comm = mpi.compilation_comm(self.pcomm) - - def __del__(self): - if hasattr(self, "comm"): - mpi.decref(self.comm) - if hasattr(self, "pcomm"): - mpi.decref(self.pcomm) + self.pcomm = mpi.internal_comm(comm, self) + self.comm = mpi.compilation_comm(self.pcomm, self) def __repr__(self): return f"<{self._name} compiler, version {self.version or 'unknown'}>" diff --git a/pyop2/mpi.py b/pyop2/mpi.py index 955bc59a7..04652ee03 100644 --- a/pyop2/mpi.py +++ b/pyop2/mpi.py @@ -42,6 +42,7 @@ import glob import os import tempfile +import weakref from pyop2.configuration import configuration from pyop2.exceptions import CompilationError @@ -266,10 +267,7 @@ class temp_internal_comm: """ def __init__(self, comm): self.user_comm = comm - self.internal_comm = internal_comm(self.user_comm) - - def __del__(self): - decref(self.internal_comm) + self.internal_comm = internal_comm(self.user_comm, self) def __enter__(self): """ Returns an internal comm that will be safely decref'd @@ -283,10 +281,12 @@ def __exit__(self, exc_type, exc_value, traceback): pass -def internal_comm(comm): +def internal_comm(comm, obj): """ Creates an internal comm from the user comm. If comm is None, create an internal communicator from COMM_WORLD :arg comm: A communicator or None + :arg obj: The object which the comm is an attribute of + (usually `self`) :returns pyop2_comm: A PyOP2 internal communicator """ @@ -309,6 +309,7 @@ def internal_comm(comm): pyop2_comm = comm else: pyop2_comm = dup_comm(comm) + weakref.finalize(obj, decref, pyop2_comm) return pyop2_comm @@ -445,10 +446,13 @@ def set_compilation_comm(comm, comp_comm): @collective -def compilation_comm(comm): +def compilation_comm(comm, obj): """Get a communicator for compilation. :arg comm: The input communicator, must be a PyOP2 comm. + :arg obj: The object which the comm is an attribute of + (usually `self`) + :returns: A communicator used for compilation (may be smaller) """ if not is_pyop2_comm(comm): @@ -470,10 +474,20 @@ def compilation_comm(comm): else: comp_comm = comm incref(comp_comm) + weakref.finalize(obj, decref, comp_comm) return comp_comm def finalize_safe_debug(): + ''' Return function for debug output. + + When Python is finalizing the logging module may be finalized before we have + finished writing debug information. In this case we fall back to using the + Python `print` function to output debugging information. + + Furthermore, we always want to see this finalization information when + running the CI tests. + ''' if PYOP2_FINALIZED: if logger.level > DEBUG and not _running_on_ci: debug = lambda string: None diff --git a/pyop2/parloop.py b/pyop2/parloop.py index 776b58c8d..cf96ba5b4 100644 --- a/pyop2/parloop.py +++ b/pyop2/parloop.py @@ -151,13 +151,9 @@ def __init__(self, global_knl, iterset, arguments): self.global_kernel = global_knl self.iterset = iterset - self.comm = mpi.internal_comm(iterset.comm) + self.comm = mpi.internal_comm(iterset.comm, self) self.arguments, self.reduced_globals = self.prepare_reduced_globals(arguments, global_knl) - def __del__(self): - if hasattr(self, "comm"): - mpi.decref(self.comm) - @property def local_kernel(self): return self.global_kernel.local_kernel diff --git a/pyop2/types/dat.py b/pyop2/types/dat.py index 826921e67..5ee339bcc 100644 --- a/pyop2/types/dat.py +++ b/pyop2/types/dat.py @@ -82,17 +82,13 @@ def __init__(self, dataset, data=None, dtype=None, name=None): EmptyDataMixin.__init__(self, data, dtype, self._shape) self._dataset = dataset - self.comm = mpi.internal_comm(dataset.comm) + self.comm = mpi.internal_comm(dataset.comm, self) self.halo_valid = True self._name = name or "dat_#x%x" % id(self) self._halo_frozen = False self._frozen_access_mode = None - def __del__(self): - if hasattr(self, "comm"): - mpi.decref(self.comm) - @utils.cached_property def _kernel_args_(self): return (self._data.ctypes.data, ) @@ -823,7 +819,7 @@ def what(x): if not all(d.dtype == self._dats[0].dtype for d in self._dats): raise ex.DataValueError('MixedDat with different dtypes is not supported') # TODO: Think about different communicators on dats (c.f. MixedSet) - self.comm = mpi.internal_comm(self._dats[0].comm) + self.comm = mpi.internal_comm(self._dats[0].comm, self) @property def dat_version(self): diff --git a/pyop2/types/dataset.py b/pyop2/types/dataset.py index 4e114032a..5b02f27ca 100644 --- a/pyop2/types/dataset.py +++ b/pyop2/types/dataset.py @@ -29,19 +29,13 @@ def __init__(self, iter_set, dim=1, name=None): return if isinstance(iter_set, Subset): raise NotImplementedError("Deriving a DataSet from a Subset is unsupported") - self.comm = mpi.internal_comm(iter_set.comm) + self.comm = mpi.internal_comm(iter_set.comm, self) self._set = iter_set self._dim = utils.as_tuple(dim, numbers.Integral) self._cdim = np.prod(self._dim).item() self._name = name or "dset_#x%x" % id(self) self._initialized = True - def __del__(self): - # Cannot use hasattr here, since we define `__getattr__` - # This causes infinite recursion when looked up! - if "comm" in self.__dict__: - mpi.decref(self.comm) - @classmethod def _process_args(cls, *args, **kwargs): return (args[0], ) + args, kwargs @@ -211,7 +205,7 @@ def __init__(self, global_): if self._initialized: return self._global = global_ - self.comm = mpi.internal_comm(global_.comm) + self.comm = mpi.internal_comm(global_.comm, self) self._globalset = GlobalSet(comm=self.comm) self._name = "gdset_#x%x" % id(self) self._initialized = True @@ -381,7 +375,7 @@ def __init__(self, arg, dims=None): comm = self._process_args(arg, dims)[0][0].comm except AttributeError: comm = None - self.comm = mpi.internal_comm(comm) + self.comm = mpi.internal_comm(comm, self) self._initialized = True @classmethod diff --git a/pyop2/types/glob.py b/pyop2/types/glob.py index daacc6a64..d8ed99134 100644 --- a/pyop2/types/glob.py +++ b/pyop2/types/glob.py @@ -26,10 +26,6 @@ def __init__(self, dim, data=None, dtype=None, name=None): self._buf = np.empty(self.shape, dtype=self.dtype) self._name = name or "%s_#x%x" % (self.__class__.__name__.lower(), id(self)) - def __del__(self): - if hasattr(self, "comm"): - mpi.decref(self.comm) - @utils.cached_property def _kernel_args_(self): return (self._data.ctypes.data, ) @@ -247,16 +243,12 @@ def __init__(self, dim, data=None, dtype=None, name=None, comm=None): super().__init__(dim, data, dtype, name) if comm is None: warnings.warn("PyOP2.Global has no comm, this is likely to break in parallel!") - self.comm = mpi.internal_comm(comm) + self.comm = mpi.internal_comm(comm, self) # Object versioning setup petsc_counter = (comm and self.dtype == PETSc.ScalarType) VecAccessMixin.__init__(self, petsc_counter=petsc_counter) - def __del__(self): - if hasattr(self, "comm"): - mpi.decref(self.comm) - def __str__(self): return "OP2 Global Argument: %s with dim %s and value %s" \ % (self._name, self._dim, self._data) diff --git a/pyop2/types/map.py b/pyop2/types/map.py index 91224d52a..9d9ca48ae 100644 --- a/pyop2/types/map.py +++ b/pyop2/types/map.py @@ -36,7 +36,7 @@ class Map: def __init__(self, iterset, toset, arity, values=None, name=None, offset=None, offset_quotient=None): self._iterset = iterset self._toset = toset - self.comm = mpi.internal_comm(toset.comm) + self.comm = mpi.internal_comm(toset.comm, self) self._arity = arity self._values = utils.verify_reshape(values, dtypes.IntType, (iterset.total_size, arity), allow_none=True) @@ -53,10 +53,6 @@ def __init__(self, iterset, toset, arity, values=None, name=None, offset=None, o # A cache for objects built on top of this map self._cache = {} - def __del__(self): - if hasattr(self, "comm"): - mpi.decref(self.comm) - @utils.cached_property def _kernel_args_(self): return (self._values.ctypes.data, ) @@ -200,7 +196,7 @@ def __init__(self, map_, permutation): if isinstance(map_, ComposedMap): raise NotImplementedError("PermutedMap of ComposedMap not implemented: simply permute before composing") self.map_ = map_ - self.comm = mpi.internal_comm(map_.comm) + self.comm = mpi.internal_comm(map_.comm, self) self.permutation = np.asarray(permutation, dtype=Map.dtype) assert (np.unique(permutation) == np.arange(map_.arity, dtype=Map.dtype)).all() @@ -251,7 +247,7 @@ def __init__(self, *maps_, name=None): raise ex.MapTypeError("frommap.arity must be 1") self._iterset = maps_[-1].iterset self._toset = maps_[0].toset - self.comm = mpi.internal_comm(self._toset.comm) + self.comm = mpi.internal_comm(self._toset.comm, self) self._arity = maps_[0].arity # Don't call super().__init__() to avoid calling verify_reshape() self._values = None @@ -315,7 +311,7 @@ def __init__(self, maps): raise ex.MapTypeError("All maps needs to share a communicator") if len(comms) == 0: raise ex.MapTypeError("Don't know how to make communicator") - self.comm = mpi.internal_comm(comms[0]) + self.comm = mpi.internal_comm(comms[0], self) self._initialized = True @classmethod diff --git a/pyop2/types/mat.py b/pyop2/types/mat.py index aefd77de1..035fa19b1 100644 --- a/pyop2/types/mat.py +++ b/pyop2/types/mat.py @@ -68,11 +68,17 @@ def __init__(self, dsets, maps, *, iteration_regions=None, name=None, nest=None, self._o_nnz = None self._nrows = None if isinstance(dsets[0], GlobalDataSet) else self._rmaps[0].toset.size self._ncols = None if isinstance(dsets[1], GlobalDataSet) else self._cmaps[0].toset.size - self.lcomm = mpi.internal_comm(dsets[0].comm if isinstance(dsets[0], GlobalDataSet) else self._rmaps[0].comm) - self.rcomm = mpi.internal_comm(dsets[1].comm if isinstance(dsets[1], GlobalDataSet) else self._cmaps[0].comm) + self.lcomm = mpi.internal_comm( + dsets[0].comm if isinstance(dsets[0], GlobalDataSet) else self._rmaps[0].comm, + self + ) + self.rcomm = mpi.internal_comm( + dsets[1].comm if isinstance(dsets[1], GlobalDataSet) else self._cmaps[0].comm, + self + ) else: - self.lcomm = mpi.internal_comm(self._rmaps[0].comm) - self.rcomm = mpi.internal_comm(self._cmaps[0].comm) + self.lcomm = mpi.internal_comm(self._rmaps[0].comm, self) + self.rcomm = mpi.internal_comm(self._cmaps[0].comm, self) rset, cset = self.dsets # All rmaps and cmaps have the same data set - just use the first. @@ -93,7 +99,7 @@ def __init__(self, dsets, maps, *, iteration_regions=None, name=None, nest=None, if self.lcomm != self.rcomm: raise ValueError("Haven't thought hard enough about different left and right communicators") - self.comm = mpi.internal_comm(self.lcomm) + self.comm = mpi.internal_comm(self.lcomm, self) self._name = name or "sparsity_#x%x" % id(self) self.iteration_regions = iteration_regions # If the Sparsity is defined on MixedDataSets, we need to build each @@ -129,14 +135,6 @@ def __init__(self, dsets, maps, *, iteration_regions=None, name=None, nest=None, self._blocks = [[self]] self._initialized = True - def __del__(self): - if hasattr(self, "comm"): - mpi.decref(self.comm) - if hasattr(self, "lcomm"): - mpi.decref(self.lcomm) - if hasattr(self, "rcomm"): - mpi.decref(self.rcomm) - _cache = {} @classmethod @@ -383,10 +381,10 @@ def __init__(self, parent, i, j): self._dims = tuple([tuple([parent.dims[i][j]])]) self._blocks = [[self]] self.iteration_regions = parent.iteration_regions - self.lcomm = mpi.internal_comm(self.dsets[0].comm) - self.rcomm = mpi.internal_comm(self.dsets[1].comm) + self.lcomm = mpi.internal_comm(self.dsets[0].comm, self) + self.rcomm = mpi.internal_comm(self.dsets[1].comm, self) # TODO: think about lcomm != rcomm - self.comm = mpi.internal_comm(self.lcomm) + self.comm = mpi.internal_comm(self.lcomm, self) self._initialized = True @classmethod @@ -445,22 +443,14 @@ class AbstractMat(DataCarrier, abc.ABC): ('name', str, ex.NameTypeError)) def __init__(self, sparsity, dtype=None, name=None): self._sparsity = sparsity - self.lcomm = mpi.internal_comm(sparsity.lcomm) - self.rcomm = mpi.internal_comm(sparsity.rcomm) - self.comm = mpi.internal_comm(sparsity.comm) + self.lcomm = mpi.internal_comm(sparsity.lcomm, self) + self.rcomm = mpi.internal_comm(sparsity.rcomm, self) + self.comm = mpi.internal_comm(sparsity.comm, self) dtype = dtype or dtypes.ScalarType self._datatype = np.dtype(dtype) self._name = name or "mat_#x%x" % id(self) self.assembly_state = Mat.ASSEMBLED - def __del__(self): - if hasattr(self, "comm"): - mpi.decref(self.comm) - if hasattr(self, "lcomm"): - mpi.decref(self.lcomm) - if hasattr(self, "rcomm"): - mpi.decref(self.rcomm) - @utils.validate_in(('access', _modes, ex.ModeValueError)) def __call__(self, access, path, lgmaps=None, unroll_map=False): from pyop2.parloop import MatLegacyArg, MixedMatLegacyArg @@ -958,7 +948,7 @@ def __init__(self, parent, i, j): colis = cset.local_ises[j] self.handle = parent.handle.getLocalSubMatrix(isrow=rowis, iscol=colis) - self.comm = mpi.internal_comm(parent.comm) + self.comm = mpi.internal_comm(parent.comm, self) self.local_to_global_maps = self.handle.getLGMap() @property diff --git a/pyop2/types/set.py b/pyop2/types/set.py index f6b09e9d2..25abdf93c 100644 --- a/pyop2/types/set.py +++ b/pyop2/types/set.py @@ -1,8 +1,8 @@ import ctypes -import functools import numbers import numpy as np +import pytools from pyop2 import ( caching, @@ -65,7 +65,7 @@ def _wrapper_cache_key_(self): @utils.validate_type(('size', (numbers.Integral, tuple, list, np.ndarray), ex.SizeTypeError), ('name', str, ex.NameTypeError)) def __init__(self, size, name=None, halo=None, comm=None): - self.comm = mpi.internal_comm(comm) + self.comm = mpi.internal_comm(comm, self) if isinstance(size, numbers.Integral): size = [size] * 3 size = utils.as_tuple(size, numbers.Integral, 3) @@ -78,12 +78,6 @@ def __init__(self, size, name=None, halo=None, comm=None): # A cache of objects built on top of this set self._cache = {} - def __del__(self): - # Cannot use hasattr here, since child classes define `__getattr__` - # This causes infinite recursion when looked up! - if "comm" in self.__dict__: - mpi.decref(self.comm) - @utils.cached_property def core_size(self): """Core set size. Owned elements not touching halo elements.""" @@ -233,7 +227,7 @@ class GlobalSet(Set): _argtypes_ = () def __init__(self, comm=None): - self.comm = mpi.internal_comm(comm) + self.comm = mpi.internal_comm(comm, self) self._cache = {} @utils.cached_property @@ -318,7 +312,7 @@ class ExtrudedSet(Set): @utils.validate_type(('parent', Set, TypeError)) def __init__(self, parent, layers, extruded_periodic=False): self._parent = parent - self.comm = mpi.internal_comm(parent.comm) + self.comm = mpi.internal_comm(parent.comm, self) try: layers = utils.verify_reshape(layers, dtypes.IntType, (parent.total_size, 2)) self.constant_layers = False @@ -399,7 +393,7 @@ class Subset(ExtrudedSet): @utils.validate_type(('superset', Set, TypeError), ('indices', (list, tuple, np.ndarray), TypeError)) def __init__(self, superset, indices): - self.comm = mpi.internal_comm(superset.comm) + self.comm = mpi.internal_comm(superset.comm, self) # sort and remove duplicates indices = np.unique(indices) @@ -543,13 +537,12 @@ def __init__(self, sets): assert all(s is None or isinstance(s, GlobalSet) or ((s.layers == self._sets[0].layers).all() if s.layers is not None else True) for s in sets), \ "All components of a MixedSet must have the same number of layers." # TODO: do all sets need the same communicator? - self.comm = mpi.internal_comm(functools.reduce(lambda a, b: a or b, map(lambda s: s if s is None else s.comm, sets))) + self.comm = mpi.internal_comm( + pytools.single_valued(s.comm for s in sets if s is not None), + self + ) self._initialized = True - def __del__(self): - if self._initialized and hasattr(self, "comm"): - mpi.decref(self.comm) - @utils.cached_property def _kernel_args_(self): raise NotImplementedError diff --git a/requirements-ext.txt b/requirements-ext.txt index 0f19e0d06..2ccb04374 100644 --- a/requirements-ext.txt +++ b/requirements-ext.txt @@ -8,3 +8,4 @@ decorator<=4.4.2 dataclasses cachetools packaging +pytools diff --git a/setup.py b/setup.py index ad9f7815b..06c03e152 100644 --- a/setup.py +++ b/setup.py @@ -89,6 +89,7 @@ def get_petsc_dir(): 'decorator', 'mpi4py', 'numpy>=1.6', + 'pytools', ] version = sys.version_info[:2] diff --git a/test/unit/test_caching.py b/test/unit/test_caching.py index 1c43ce52f..1b95abebc 100644 --- a/test/unit/test_caching.py +++ b/test/unit/test_caching.py @@ -536,8 +536,7 @@ def myfunc(arg): def collective_key(self, *args): """Return a cache key suitable for use when collective over a communicator.""" - # Explicitly `mpi.decref(self.comm)` in any test that uses this comm - self.comm = mpi.internal_comm(mpi.COMM_SELF) + self.comm = mpi.internal_comm(mpi.COMM_SELF, self) return self.comm, cachetools.keys.hashkey(*args) @pytest.fixture @@ -575,7 +574,6 @@ def test_decorator_collective_has_different_in_memory_key(self, cache, cachedir) assert obj1 == obj2 and obj1 is not obj2 assert len(cache) == 2 assert len(os.listdir(cachedir.name)) == 1 - mpi.decref(self.comm) def test_decorator_disk_cache_reuses_results(self, cache, cachedir): decorated_func = disk_cached(cache, cachedir.name)(self.myfunc)