From ca73fb6184f2b6cd8e32ad20c816d280597e8f9a Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Tue, 14 Nov 2023 14:27:54 +0000 Subject: [PATCH 1/6] Remove outdated information --- pyop2/caching.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/pyop2/caching.py b/pyop2/caching.py index 24a3f5513..0f036212f 100644 --- a/pyop2/caching.py +++ b/pyop2/caching.py @@ -83,13 +83,6 @@ class ObjectCached(object): details). The object on which the cache is stored should contain a dict in its ``_cache`` attribute. - .. warning :: - - This kind of cache sets up a circular reference. If either of - the objects implements ``__del__``, the Python garbage - collector will not be able to collect this cycle, and hence - the cache will never be evicted. - .. warning:: The derived class' :meth:`__init__` is still called if the From fee60d29be5325a62b6bab99197c026a5db7b970 Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Tue, 14 Nov 2023 15:02:41 +0000 Subject: [PATCH 2/6] Replace __del__ with weakref.finalizer --- pyop2/compilation.py | 9 +++------ pyop2/mpi.py | 5 ++--- pyop2/parloop.py | 6 ++---- pyop2/types/dat.py | 7 +++---- pyop2/types/dataset.py | 10 ++++------ pyop2/types/glob.py | 12 +++++------- pyop2/types/map.py | 9 +++++---- pyop2/types/mat.py | 29 +++++++++++++---------------- pyop2/types/set.py | 16 ++++++---------- 9 files changed, 43 insertions(+), 60 deletions(-) diff --git a/pyop2/compilation.py b/pyop2/compilation.py index ffd01f3b5..249c67bdb 100644 --- a/pyop2/compilation.py +++ b/pyop2/compilation.py @@ -42,6 +42,7 @@ import shlex from hashlib import md5 from packaging.version import Version, InvalidVersion +import weakref from pyop2 import mpi @@ -189,13 +190,9 @@ def __init__(self, extra_compiler_flags=(), extra_linker_flags=(), cpp=False, co # Compilation communicators are reference counted on the PyOP2 comm self.pcomm = mpi.internal_comm(comm) + weakref.finalize(self, mpi.decref, self.pcomm) self.comm = mpi.compilation_comm(self.pcomm) - - def __del__(self): - if hasattr(self, "comm"): - mpi.decref(self.comm) - if hasattr(self, "pcomm"): - mpi.decref(self.pcomm) + weakref.finalize(self, mpi.decref, self.comm) def __repr__(self): return f"<{self._name} compiler, version {self.version or 'unknown'}>" diff --git a/pyop2/mpi.py b/pyop2/mpi.py index 955bc59a7..2c2cd780e 100644 --- a/pyop2/mpi.py +++ b/pyop2/mpi.py @@ -42,6 +42,7 @@ import glob import os import tempfile +import weakref from pyop2.configuration import configuration from pyop2.exceptions import CompilationError @@ -267,9 +268,7 @@ class temp_internal_comm: def __init__(self, comm): self.user_comm = comm self.internal_comm = internal_comm(self.user_comm) - - def __del__(self): - decref(self.internal_comm) + weakref.finalize(self, decref, self.internal_comm) def __enter__(self): """ Returns an internal comm that will be safely decref'd diff --git a/pyop2/parloop.py b/pyop2/parloop.py index 776b58c8d..e3ae7aa85 100644 --- a/pyop2/parloop.py +++ b/pyop2/parloop.py @@ -3,6 +3,7 @@ import operator from dataclasses import dataclass from typing import Any, Optional, Tuple +import weakref import loopy as lp import numpy as np @@ -152,12 +153,9 @@ def __init__(self, global_knl, iterset, arguments): self.global_kernel = global_knl self.iterset = iterset self.comm = mpi.internal_comm(iterset.comm) + weakref.finalize(self, mpi.decref, self.comm) self.arguments, self.reduced_globals = self.prepare_reduced_globals(arguments, global_knl) - def __del__(self): - if hasattr(self, "comm"): - mpi.decref(self.comm) - @property def local_kernel(self): return self.global_kernel.local_kernel diff --git a/pyop2/types/dat.py b/pyop2/types/dat.py index 826921e67..9f055e3b2 100644 --- a/pyop2/types/dat.py +++ b/pyop2/types/dat.py @@ -3,6 +3,7 @@ import ctypes import itertools import operator +import weakref import loopy as lp import numpy as np @@ -83,16 +84,13 @@ def __init__(self, dataset, data=None, dtype=None, name=None): self._dataset = dataset self.comm = mpi.internal_comm(dataset.comm) + weakref.finalize(self, mpi.decref, self.comm) self.halo_valid = True self._name = name or "dat_#x%x" % id(self) self._halo_frozen = False self._frozen_access_mode = None - def __del__(self): - if hasattr(self, "comm"): - mpi.decref(self.comm) - @utils.cached_property def _kernel_args_(self): return (self._data.ctypes.data, ) @@ -824,6 +822,7 @@ def what(x): raise ex.DataValueError('MixedDat with different dtypes is not supported') # TODO: Think about different communicators on dats (c.f. MixedSet) self.comm = mpi.internal_comm(self._dats[0].comm) + weakref.finalize(self, mpi.decref, self.comm) @property def dat_version(self): diff --git a/pyop2/types/dataset.py b/pyop2/types/dataset.py index 4e114032a..870764934 100644 --- a/pyop2/types/dataset.py +++ b/pyop2/types/dataset.py @@ -1,4 +1,5 @@ import numbers +import weakref import numpy as np from petsc4py import PETSc @@ -30,18 +31,13 @@ def __init__(self, iter_set, dim=1, name=None): if isinstance(iter_set, Subset): raise NotImplementedError("Deriving a DataSet from a Subset is unsupported") self.comm = mpi.internal_comm(iter_set.comm) + weakref.finalize(self, mpi.decref, self.comm) self._set = iter_set self._dim = utils.as_tuple(dim, numbers.Integral) self._cdim = np.prod(self._dim).item() self._name = name or "dset_#x%x" % id(self) self._initialized = True - def __del__(self): - # Cannot use hasattr here, since we define `__getattr__` - # This causes infinite recursion when looked up! - if "comm" in self.__dict__: - mpi.decref(self.comm) - @classmethod def _process_args(cls, *args, **kwargs): return (args[0], ) + args, kwargs @@ -212,6 +208,7 @@ def __init__(self, global_): return self._global = global_ self.comm = mpi.internal_comm(global_.comm) + weakref.finalize(self, mpi.decref, self.comm) self._globalset = GlobalSet(comm=self.comm) self._name = "gdset_#x%x" % id(self) self._initialized = True @@ -382,6 +379,7 @@ def __init__(self, arg, dims=None): except AttributeError: comm = None self.comm = mpi.internal_comm(comm) + weakref.finalize(self, mpi.decref, self.comm) self._initialized = True @classmethod diff --git a/pyop2/types/glob.py b/pyop2/types/glob.py index daacc6a64..c80b7b2bc 100644 --- a/pyop2/types/glob.py +++ b/pyop2/types/glob.py @@ -2,6 +2,7 @@ import ctypes import operator import warnings +import weakref import numpy as np from petsc4py import PETSc @@ -26,9 +27,9 @@ def __init__(self, dim, data=None, dtype=None, name=None): self._buf = np.empty(self.shape, dtype=self.dtype) self._name = name or "%s_#x%x" % (self.__class__.__name__.lower(), id(self)) - def __del__(self): - if hasattr(self, "comm"): - mpi.decref(self.comm) + # ~ def __del__(self): # TODO !? + # ~ if hasattr(self, "comm"): + # ~ mpi.decref(self.comm) @utils.cached_property def _kernel_args_(self): @@ -248,15 +249,12 @@ def __init__(self, dim, data=None, dtype=None, name=None, comm=None): if comm is None: warnings.warn("PyOP2.Global has no comm, this is likely to break in parallel!") self.comm = mpi.internal_comm(comm) + weakref.finalize(self, mpi.decref, self.comm) # Object versioning setup petsc_counter = (comm and self.dtype == PETSc.ScalarType) VecAccessMixin.__init__(self, petsc_counter=petsc_counter) - def __del__(self): - if hasattr(self, "comm"): - mpi.decref(self.comm) - def __str__(self): return "OP2 Global Argument: %s with dim %s and value %s" \ % (self._name, self._dim, self._data) diff --git a/pyop2/types/map.py b/pyop2/types/map.py index 91224d52a..49d49e953 100644 --- a/pyop2/types/map.py +++ b/pyop2/types/map.py @@ -1,6 +1,7 @@ import itertools import functools import numbers +import weakref import numpy as np @@ -37,6 +38,7 @@ def __init__(self, iterset, toset, arity, values=None, name=None, offset=None, o self._iterset = iterset self._toset = toset self.comm = mpi.internal_comm(toset.comm) + weakref.finalize(self, mpi.decref, self.comm) self._arity = arity self._values = utils.verify_reshape(values, dtypes.IntType, (iterset.total_size, arity), allow_none=True) @@ -53,10 +55,6 @@ def __init__(self, iterset, toset, arity, values=None, name=None, offset=None, o # A cache for objects built on top of this map self._cache = {} - def __del__(self): - if hasattr(self, "comm"): - mpi.decref(self.comm) - @utils.cached_property def _kernel_args_(self): return (self._values.ctypes.data, ) @@ -201,6 +199,7 @@ def __init__(self, map_, permutation): raise NotImplementedError("PermutedMap of ComposedMap not implemented: simply permute before composing") self.map_ = map_ self.comm = mpi.internal_comm(map_.comm) + weakref.finalize(self, mpi.decref, self.comm) self.permutation = np.asarray(permutation, dtype=Map.dtype) assert (np.unique(permutation) == np.arange(map_.arity, dtype=Map.dtype)).all() @@ -252,6 +251,7 @@ def __init__(self, *maps_, name=None): self._iterset = maps_[-1].iterset self._toset = maps_[0].toset self.comm = mpi.internal_comm(self._toset.comm) + weakref.finalize(self, mpi.decref, self.comm) self._arity = maps_[0].arity # Don't call super().__init__() to avoid calling verify_reshape() self._values = None @@ -316,6 +316,7 @@ def __init__(self, maps): if len(comms) == 0: raise ex.MapTypeError("Don't know how to make communicator") self.comm = mpi.internal_comm(comms[0]) + weakref.finalize(self, mpi.decref, self.comm) self._initialized = True @classmethod diff --git a/pyop2/types/mat.py b/pyop2/types/mat.py index aefd77de1..baf5e5a2b 100644 --- a/pyop2/types/mat.py +++ b/pyop2/types/mat.py @@ -1,6 +1,7 @@ import abc import ctypes import itertools +import weakref import numpy as np from petsc4py import PETSc @@ -69,10 +70,14 @@ def __init__(self, dsets, maps, *, iteration_regions=None, name=None, nest=None, self._nrows = None if isinstance(dsets[0], GlobalDataSet) else self._rmaps[0].toset.size self._ncols = None if isinstance(dsets[1], GlobalDataSet) else self._cmaps[0].toset.size self.lcomm = mpi.internal_comm(dsets[0].comm if isinstance(dsets[0], GlobalDataSet) else self._rmaps[0].comm) + weakref.finalize(self, mpi.decref, self.lcomm) self.rcomm = mpi.internal_comm(dsets[1].comm if isinstance(dsets[1], GlobalDataSet) else self._cmaps[0].comm) + weakref.finalize(self, mpi.decref, self.rcomm) else: self.lcomm = mpi.internal_comm(self._rmaps[0].comm) + weakref.finalize(self, mpi.decref, self.lcomm) self.rcomm = mpi.internal_comm(self._cmaps[0].comm) + weakref.finalize(self, mpi.decref, self.rcomm) rset, cset = self.dsets # All rmaps and cmaps have the same data set - just use the first. @@ -94,6 +99,7 @@ def __init__(self, dsets, maps, *, iteration_regions=None, name=None, nest=None, if self.lcomm != self.rcomm: raise ValueError("Haven't thought hard enough about different left and right communicators") self.comm = mpi.internal_comm(self.lcomm) + weakref.finalize(self, mpi.decref, self.comm) self._name = name or "sparsity_#x%x" % id(self) self.iteration_regions = iteration_regions # If the Sparsity is defined on MixedDataSets, we need to build each @@ -129,14 +135,6 @@ def __init__(self, dsets, maps, *, iteration_regions=None, name=None, nest=None, self._blocks = [[self]] self._initialized = True - def __del__(self): - if hasattr(self, "comm"): - mpi.decref(self.comm) - if hasattr(self, "lcomm"): - mpi.decref(self.lcomm) - if hasattr(self, "rcomm"): - mpi.decref(self.rcomm) - _cache = {} @classmethod @@ -384,9 +382,12 @@ def __init__(self, parent, i, j): self._blocks = [[self]] self.iteration_regions = parent.iteration_regions self.lcomm = mpi.internal_comm(self.dsets[0].comm) + weakref.finalize(self, mpi.decref, self.lcomm) self.rcomm = mpi.internal_comm(self.dsets[1].comm) + weakref.finalize(self, mpi.decref, self.rcomm) # TODO: think about lcomm != rcomm self.comm = mpi.internal_comm(self.lcomm) + weakref.finalize(self, mpi.decref, self.comm) self._initialized = True @classmethod @@ -446,21 +447,16 @@ class AbstractMat(DataCarrier, abc.ABC): def __init__(self, sparsity, dtype=None, name=None): self._sparsity = sparsity self.lcomm = mpi.internal_comm(sparsity.lcomm) + weakref.finalize(self, mpi.decref, self.lcomm) self.rcomm = mpi.internal_comm(sparsity.rcomm) + weakref.finalize(self, mpi.decref, self.rcomm) self.comm = mpi.internal_comm(sparsity.comm) + weakref.finalize(self, mpi.decref, self.comm) dtype = dtype or dtypes.ScalarType self._datatype = np.dtype(dtype) self._name = name or "mat_#x%x" % id(self) self.assembly_state = Mat.ASSEMBLED - def __del__(self): - if hasattr(self, "comm"): - mpi.decref(self.comm) - if hasattr(self, "lcomm"): - mpi.decref(self.lcomm) - if hasattr(self, "rcomm"): - mpi.decref(self.rcomm) - @utils.validate_in(('access', _modes, ex.ModeValueError)) def __call__(self, access, path, lgmaps=None, unroll_map=False): from pyop2.parloop import MatLegacyArg, MixedMatLegacyArg @@ -959,6 +955,7 @@ def __init__(self, parent, i, j): self.handle = parent.handle.getLocalSubMatrix(isrow=rowis, iscol=colis) self.comm = mpi.internal_comm(parent.comm) + weakref.finalize(self, mpi.decref, self.comm) self.local_to_global_maps = self.handle.getLGMap() @property diff --git a/pyop2/types/set.py b/pyop2/types/set.py index f6b09e9d2..c17cf4812 100644 --- a/pyop2/types/set.py +++ b/pyop2/types/set.py @@ -1,6 +1,7 @@ import ctypes import functools import numbers +import weakref import numpy as np @@ -66,6 +67,7 @@ def _wrapper_cache_key_(self): ('name', str, ex.NameTypeError)) def __init__(self, size, name=None, halo=None, comm=None): self.comm = mpi.internal_comm(comm) + weakref.finalize(self, mpi.decref, self.comm) if isinstance(size, numbers.Integral): size = [size] * 3 size = utils.as_tuple(size, numbers.Integral, 3) @@ -78,12 +80,6 @@ def __init__(self, size, name=None, halo=None, comm=None): # A cache of objects built on top of this set self._cache = {} - def __del__(self): - # Cannot use hasattr here, since child classes define `__getattr__` - # This causes infinite recursion when looked up! - if "comm" in self.__dict__: - mpi.decref(self.comm) - @utils.cached_property def core_size(self): """Core set size. Owned elements not touching halo elements.""" @@ -234,6 +230,7 @@ class GlobalSet(Set): def __init__(self, comm=None): self.comm = mpi.internal_comm(comm) + weakref.finalize(self, mpi.decref, self.comm) self._cache = {} @utils.cached_property @@ -319,6 +316,7 @@ class ExtrudedSet(Set): def __init__(self, parent, layers, extruded_periodic=False): self._parent = parent self.comm = mpi.internal_comm(parent.comm) + weakref.finalize(self, mpi.decref, self.comm) try: layers = utils.verify_reshape(layers, dtypes.IntType, (parent.total_size, 2)) self.constant_layers = False @@ -400,6 +398,7 @@ class Subset(ExtrudedSet): ('indices', (list, tuple, np.ndarray), TypeError)) def __init__(self, superset, indices): self.comm = mpi.internal_comm(superset.comm) + weakref.finalize(self, mpi.decref, self.comm) # sort and remove duplicates indices = np.unique(indices) @@ -544,12 +543,9 @@ def __init__(self, sets): "All components of a MixedSet must have the same number of layers." # TODO: do all sets need the same communicator? self.comm = mpi.internal_comm(functools.reduce(lambda a, b: a or b, map(lambda s: s if s is None else s.comm, sets))) + weakref.finalize(self, mpi.decref, self.comm) self._initialized = True - def __del__(self): - if self._initialized and hasattr(self, "comm"): - mpi.decref(self.comm) - @utils.cached_property def _kernel_args_(self): raise NotImplementedError From 560d5b99850c64bbc5d4cde0ce309860ee694846 Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Thu, 16 Nov 2023 16:31:42 +0000 Subject: [PATCH 3/6] Remove dead code --- pyop2/types/glob.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pyop2/types/glob.py b/pyop2/types/glob.py index c80b7b2bc..89257f802 100644 --- a/pyop2/types/glob.py +++ b/pyop2/types/glob.py @@ -27,10 +27,6 @@ def __init__(self, dim, data=None, dtype=None, name=None): self._buf = np.empty(self.shape, dtype=self.dtype) self._name = name or "%s_#x%x" % (self.__class__.__name__.lower(), id(self)) - # ~ def __del__(self): # TODO !? - # ~ if hasattr(self, "comm"): - # ~ mpi.decref(self.comm) - @utils.cached_property def _kernel_args_(self): return (self._data.ctypes.data, ) From 8979f6afcf305a7c353a081c2dc9024dfbc80085 Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Thu, 16 Nov 2023 16:51:12 +0000 Subject: [PATCH 4/6] A better idea --- pyop2/compilation.py | 7 ++----- pyop2/mpi.py | 14 +++++++++---- pyop2/parloop.py | 4 +--- pyop2/types/dat.py | 7 ++----- pyop2/types/dataset.py | 10 +++------ pyop2/types/glob.py | 4 +--- pyop2/types/map.py | 13 ++++-------- pyop2/types/mat.py | 43 ++++++++++++++++----------------------- pyop2/types/set.py | 22 ++++++++++---------- test/unit/test_caching.py | 4 +--- 10 files changed, 53 insertions(+), 75 deletions(-) diff --git a/pyop2/compilation.py b/pyop2/compilation.py index 249c67bdb..794024a8d 100644 --- a/pyop2/compilation.py +++ b/pyop2/compilation.py @@ -42,7 +42,6 @@ import shlex from hashlib import md5 from packaging.version import Version, InvalidVersion -import weakref from pyop2 import mpi @@ -189,10 +188,8 @@ def __init__(self, extra_compiler_flags=(), extra_linker_flags=(), cpp=False, co self._debug = configuration["debug"] # Compilation communicators are reference counted on the PyOP2 comm - self.pcomm = mpi.internal_comm(comm) - weakref.finalize(self, mpi.decref, self.pcomm) - self.comm = mpi.compilation_comm(self.pcomm) - weakref.finalize(self, mpi.decref, self.comm) + self.pcomm = mpi.internal_comm(comm, self) + self.comm = mpi.compilation_comm(self.pcomm, self) def __repr__(self): return f"<{self._name} compiler, version {self.version or 'unknown'}>" diff --git a/pyop2/mpi.py b/pyop2/mpi.py index 2c2cd780e..d6e8ed713 100644 --- a/pyop2/mpi.py +++ b/pyop2/mpi.py @@ -267,8 +267,7 @@ class temp_internal_comm: """ def __init__(self, comm): self.user_comm = comm - self.internal_comm = internal_comm(self.user_comm) - weakref.finalize(self, decref, self.internal_comm) + self.internal_comm = internal_comm(self.user_comm, self) def __enter__(self): """ Returns an internal comm that will be safely decref'd @@ -282,10 +281,12 @@ def __exit__(self, exc_type, exc_value, traceback): pass -def internal_comm(comm): +def internal_comm(comm, obj): """ Creates an internal comm from the user comm. If comm is None, create an internal communicator from COMM_WORLD :arg comm: A communicator or None + :arg obj: The object which the comm is an attribute of + (usually `self`) :returns pyop2_comm: A PyOP2 internal communicator """ @@ -308,6 +309,7 @@ def internal_comm(comm): pyop2_comm = comm else: pyop2_comm = dup_comm(comm) + weakref.finalize(obj, decref, pyop2_comm) return pyop2_comm @@ -444,10 +446,13 @@ def set_compilation_comm(comm, comp_comm): @collective -def compilation_comm(comm): +def compilation_comm(comm, obj): """Get a communicator for compilation. :arg comm: The input communicator, must be a PyOP2 comm. + :arg obj: The object which the comm is an attribute of + (usually `self`) + :returns: A communicator used for compilation (may be smaller) """ if not is_pyop2_comm(comm): @@ -469,6 +474,7 @@ def compilation_comm(comm): else: comp_comm = comm incref(comp_comm) + weakref.finalize(obj, decref, comp_comm) return comp_comm diff --git a/pyop2/parloop.py b/pyop2/parloop.py index e3ae7aa85..cf96ba5b4 100644 --- a/pyop2/parloop.py +++ b/pyop2/parloop.py @@ -3,7 +3,6 @@ import operator from dataclasses import dataclass from typing import Any, Optional, Tuple -import weakref import loopy as lp import numpy as np @@ -152,8 +151,7 @@ def __init__(self, global_knl, iterset, arguments): self.global_kernel = global_knl self.iterset = iterset - self.comm = mpi.internal_comm(iterset.comm) - weakref.finalize(self, mpi.decref, self.comm) + self.comm = mpi.internal_comm(iterset.comm, self) self.arguments, self.reduced_globals = self.prepare_reduced_globals(arguments, global_knl) @property diff --git a/pyop2/types/dat.py b/pyop2/types/dat.py index 9f055e3b2..5ee339bcc 100644 --- a/pyop2/types/dat.py +++ b/pyop2/types/dat.py @@ -3,7 +3,6 @@ import ctypes import itertools import operator -import weakref import loopy as lp import numpy as np @@ -83,8 +82,7 @@ def __init__(self, dataset, data=None, dtype=None, name=None): EmptyDataMixin.__init__(self, data, dtype, self._shape) self._dataset = dataset - self.comm = mpi.internal_comm(dataset.comm) - weakref.finalize(self, mpi.decref, self.comm) + self.comm = mpi.internal_comm(dataset.comm, self) self.halo_valid = True self._name = name or "dat_#x%x" % id(self) @@ -821,8 +819,7 @@ def what(x): if not all(d.dtype == self._dats[0].dtype for d in self._dats): raise ex.DataValueError('MixedDat with different dtypes is not supported') # TODO: Think about different communicators on dats (c.f. MixedSet) - self.comm = mpi.internal_comm(self._dats[0].comm) - weakref.finalize(self, mpi.decref, self.comm) + self.comm = mpi.internal_comm(self._dats[0].comm, self) @property def dat_version(self): diff --git a/pyop2/types/dataset.py b/pyop2/types/dataset.py index 870764934..5b02f27ca 100644 --- a/pyop2/types/dataset.py +++ b/pyop2/types/dataset.py @@ -1,5 +1,4 @@ import numbers -import weakref import numpy as np from petsc4py import PETSc @@ -30,8 +29,7 @@ def __init__(self, iter_set, dim=1, name=None): return if isinstance(iter_set, Subset): raise NotImplementedError("Deriving a DataSet from a Subset is unsupported") - self.comm = mpi.internal_comm(iter_set.comm) - weakref.finalize(self, mpi.decref, self.comm) + self.comm = mpi.internal_comm(iter_set.comm, self) self._set = iter_set self._dim = utils.as_tuple(dim, numbers.Integral) self._cdim = np.prod(self._dim).item() @@ -207,8 +205,7 @@ def __init__(self, global_): if self._initialized: return self._global = global_ - self.comm = mpi.internal_comm(global_.comm) - weakref.finalize(self, mpi.decref, self.comm) + self.comm = mpi.internal_comm(global_.comm, self) self._globalset = GlobalSet(comm=self.comm) self._name = "gdset_#x%x" % id(self) self._initialized = True @@ -378,8 +375,7 @@ def __init__(self, arg, dims=None): comm = self._process_args(arg, dims)[0][0].comm except AttributeError: comm = None - self.comm = mpi.internal_comm(comm) - weakref.finalize(self, mpi.decref, self.comm) + self.comm = mpi.internal_comm(comm, self) self._initialized = True @classmethod diff --git a/pyop2/types/glob.py b/pyop2/types/glob.py index 89257f802..d8ed99134 100644 --- a/pyop2/types/glob.py +++ b/pyop2/types/glob.py @@ -2,7 +2,6 @@ import ctypes import operator import warnings -import weakref import numpy as np from petsc4py import PETSc @@ -244,8 +243,7 @@ def __init__(self, dim, data=None, dtype=None, name=None, comm=None): super().__init__(dim, data, dtype, name) if comm is None: warnings.warn("PyOP2.Global has no comm, this is likely to break in parallel!") - self.comm = mpi.internal_comm(comm) - weakref.finalize(self, mpi.decref, self.comm) + self.comm = mpi.internal_comm(comm, self) # Object versioning setup petsc_counter = (comm and self.dtype == PETSc.ScalarType) diff --git a/pyop2/types/map.py b/pyop2/types/map.py index 49d49e953..9d9ca48ae 100644 --- a/pyop2/types/map.py +++ b/pyop2/types/map.py @@ -1,7 +1,6 @@ import itertools import functools import numbers -import weakref import numpy as np @@ -37,8 +36,7 @@ class Map: def __init__(self, iterset, toset, arity, values=None, name=None, offset=None, offset_quotient=None): self._iterset = iterset self._toset = toset - self.comm = mpi.internal_comm(toset.comm) - weakref.finalize(self, mpi.decref, self.comm) + self.comm = mpi.internal_comm(toset.comm, self) self._arity = arity self._values = utils.verify_reshape(values, dtypes.IntType, (iterset.total_size, arity), allow_none=True) @@ -198,8 +196,7 @@ def __init__(self, map_, permutation): if isinstance(map_, ComposedMap): raise NotImplementedError("PermutedMap of ComposedMap not implemented: simply permute before composing") self.map_ = map_ - self.comm = mpi.internal_comm(map_.comm) - weakref.finalize(self, mpi.decref, self.comm) + self.comm = mpi.internal_comm(map_.comm, self) self.permutation = np.asarray(permutation, dtype=Map.dtype) assert (np.unique(permutation) == np.arange(map_.arity, dtype=Map.dtype)).all() @@ -250,8 +247,7 @@ def __init__(self, *maps_, name=None): raise ex.MapTypeError("frommap.arity must be 1") self._iterset = maps_[-1].iterset self._toset = maps_[0].toset - self.comm = mpi.internal_comm(self._toset.comm) - weakref.finalize(self, mpi.decref, self.comm) + self.comm = mpi.internal_comm(self._toset.comm, self) self._arity = maps_[0].arity # Don't call super().__init__() to avoid calling verify_reshape() self._values = None @@ -315,8 +311,7 @@ def __init__(self, maps): raise ex.MapTypeError("All maps needs to share a communicator") if len(comms) == 0: raise ex.MapTypeError("Don't know how to make communicator") - self.comm = mpi.internal_comm(comms[0]) - weakref.finalize(self, mpi.decref, self.comm) + self.comm = mpi.internal_comm(comms[0], self) self._initialized = True @classmethod diff --git a/pyop2/types/mat.py b/pyop2/types/mat.py index baf5e5a2b..035fa19b1 100644 --- a/pyop2/types/mat.py +++ b/pyop2/types/mat.py @@ -1,7 +1,6 @@ import abc import ctypes import itertools -import weakref import numpy as np from petsc4py import PETSc @@ -69,15 +68,17 @@ def __init__(self, dsets, maps, *, iteration_regions=None, name=None, nest=None, self._o_nnz = None self._nrows = None if isinstance(dsets[0], GlobalDataSet) else self._rmaps[0].toset.size self._ncols = None if isinstance(dsets[1], GlobalDataSet) else self._cmaps[0].toset.size - self.lcomm = mpi.internal_comm(dsets[0].comm if isinstance(dsets[0], GlobalDataSet) else self._rmaps[0].comm) - weakref.finalize(self, mpi.decref, self.lcomm) - self.rcomm = mpi.internal_comm(dsets[1].comm if isinstance(dsets[1], GlobalDataSet) else self._cmaps[0].comm) - weakref.finalize(self, mpi.decref, self.rcomm) + self.lcomm = mpi.internal_comm( + dsets[0].comm if isinstance(dsets[0], GlobalDataSet) else self._rmaps[0].comm, + self + ) + self.rcomm = mpi.internal_comm( + dsets[1].comm if isinstance(dsets[1], GlobalDataSet) else self._cmaps[0].comm, + self + ) else: - self.lcomm = mpi.internal_comm(self._rmaps[0].comm) - weakref.finalize(self, mpi.decref, self.lcomm) - self.rcomm = mpi.internal_comm(self._cmaps[0].comm) - weakref.finalize(self, mpi.decref, self.rcomm) + self.lcomm = mpi.internal_comm(self._rmaps[0].comm, self) + self.rcomm = mpi.internal_comm(self._cmaps[0].comm, self) rset, cset = self.dsets # All rmaps and cmaps have the same data set - just use the first. @@ -98,8 +99,7 @@ def __init__(self, dsets, maps, *, iteration_regions=None, name=None, nest=None, if self.lcomm != self.rcomm: raise ValueError("Haven't thought hard enough about different left and right communicators") - self.comm = mpi.internal_comm(self.lcomm) - weakref.finalize(self, mpi.decref, self.comm) + self.comm = mpi.internal_comm(self.lcomm, self) self._name = name or "sparsity_#x%x" % id(self) self.iteration_regions = iteration_regions # If the Sparsity is defined on MixedDataSets, we need to build each @@ -381,13 +381,10 @@ def __init__(self, parent, i, j): self._dims = tuple([tuple([parent.dims[i][j]])]) self._blocks = [[self]] self.iteration_regions = parent.iteration_regions - self.lcomm = mpi.internal_comm(self.dsets[0].comm) - weakref.finalize(self, mpi.decref, self.lcomm) - self.rcomm = mpi.internal_comm(self.dsets[1].comm) - weakref.finalize(self, mpi.decref, self.rcomm) + self.lcomm = mpi.internal_comm(self.dsets[0].comm, self) + self.rcomm = mpi.internal_comm(self.dsets[1].comm, self) # TODO: think about lcomm != rcomm - self.comm = mpi.internal_comm(self.lcomm) - weakref.finalize(self, mpi.decref, self.comm) + self.comm = mpi.internal_comm(self.lcomm, self) self._initialized = True @classmethod @@ -446,12 +443,9 @@ class AbstractMat(DataCarrier, abc.ABC): ('name', str, ex.NameTypeError)) def __init__(self, sparsity, dtype=None, name=None): self._sparsity = sparsity - self.lcomm = mpi.internal_comm(sparsity.lcomm) - weakref.finalize(self, mpi.decref, self.lcomm) - self.rcomm = mpi.internal_comm(sparsity.rcomm) - weakref.finalize(self, mpi.decref, self.rcomm) - self.comm = mpi.internal_comm(sparsity.comm) - weakref.finalize(self, mpi.decref, self.comm) + self.lcomm = mpi.internal_comm(sparsity.lcomm, self) + self.rcomm = mpi.internal_comm(sparsity.rcomm, self) + self.comm = mpi.internal_comm(sparsity.comm, self) dtype = dtype or dtypes.ScalarType self._datatype = np.dtype(dtype) self._name = name or "mat_#x%x" % id(self) @@ -954,8 +948,7 @@ def __init__(self, parent, i, j): colis = cset.local_ises[j] self.handle = parent.handle.getLocalSubMatrix(isrow=rowis, iscol=colis) - self.comm = mpi.internal_comm(parent.comm) - weakref.finalize(self, mpi.decref, self.comm) + self.comm = mpi.internal_comm(parent.comm, self) self.local_to_global_maps = self.handle.getLGMap() @property diff --git a/pyop2/types/set.py b/pyop2/types/set.py index c17cf4812..9a831aed6 100644 --- a/pyop2/types/set.py +++ b/pyop2/types/set.py @@ -1,7 +1,6 @@ import ctypes import functools import numbers -import weakref import numpy as np @@ -66,8 +65,7 @@ def _wrapper_cache_key_(self): @utils.validate_type(('size', (numbers.Integral, tuple, list, np.ndarray), ex.SizeTypeError), ('name', str, ex.NameTypeError)) def __init__(self, size, name=None, halo=None, comm=None): - self.comm = mpi.internal_comm(comm) - weakref.finalize(self, mpi.decref, self.comm) + self.comm = mpi.internal_comm(comm, self) if isinstance(size, numbers.Integral): size = [size] * 3 size = utils.as_tuple(size, numbers.Integral, 3) @@ -229,8 +227,7 @@ class GlobalSet(Set): _argtypes_ = () def __init__(self, comm=None): - self.comm = mpi.internal_comm(comm) - weakref.finalize(self, mpi.decref, self.comm) + self.comm = mpi.internal_comm(comm, self) self._cache = {} @utils.cached_property @@ -315,8 +312,7 @@ class ExtrudedSet(Set): @utils.validate_type(('parent', Set, TypeError)) def __init__(self, parent, layers, extruded_periodic=False): self._parent = parent - self.comm = mpi.internal_comm(parent.comm) - weakref.finalize(self, mpi.decref, self.comm) + self.comm = mpi.internal_comm(parent.comm, self) try: layers = utils.verify_reshape(layers, dtypes.IntType, (parent.total_size, 2)) self.constant_layers = False @@ -397,8 +393,7 @@ class Subset(ExtrudedSet): @utils.validate_type(('superset', Set, TypeError), ('indices', (list, tuple, np.ndarray), TypeError)) def __init__(self, superset, indices): - self.comm = mpi.internal_comm(superset.comm) - weakref.finalize(self, mpi.decref, self.comm) + self.comm = mpi.internal_comm(superset.comm, self) # sort and remove duplicates indices = np.unique(indices) @@ -542,8 +537,13 @@ def __init__(self, sets): assert all(s is None or isinstance(s, GlobalSet) or ((s.layers == self._sets[0].layers).all() if s.layers is not None else True) for s in sets), \ "All components of a MixedSet must have the same number of layers." # TODO: do all sets need the same communicator? - self.comm = mpi.internal_comm(functools.reduce(lambda a, b: a or b, map(lambda s: s if s is None else s.comm, sets))) - weakref.finalize(self, mpi.decref, self.comm) + self.comm = mpi.internal_comm( + functools.reduce( + lambda a, b: a or b, + map(lambda s: s if s is None else s.comm, sets) + ), + self + ) self._initialized = True @utils.cached_property diff --git a/test/unit/test_caching.py b/test/unit/test_caching.py index 1c43ce52f..1b95abebc 100644 --- a/test/unit/test_caching.py +++ b/test/unit/test_caching.py @@ -536,8 +536,7 @@ def myfunc(arg): def collective_key(self, *args): """Return a cache key suitable for use when collective over a communicator.""" - # Explicitly `mpi.decref(self.comm)` in any test that uses this comm - self.comm = mpi.internal_comm(mpi.COMM_SELF) + self.comm = mpi.internal_comm(mpi.COMM_SELF, self) return self.comm, cachetools.keys.hashkey(*args) @pytest.fixture @@ -575,7 +574,6 @@ def test_decorator_collective_has_different_in_memory_key(self, cache, cachedir) assert obj1 == obj2 and obj1 is not obj2 assert len(cache) == 2 assert len(os.listdir(cachedir.name)) == 1 - mpi.decref(self.comm) def test_decorator_disk_cache_reuses_results(self, cache, cachedir): decorated_func = disk_cached(cache, cachedir.name)(self.myfunc) From c7ad8bdbbd1349dd62142cb86d0715cefd16f8ce Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Thu, 23 Nov 2023 17:16:03 +0000 Subject: [PATCH 5/6] Add pytools requirements to do single_valued --- pyop2/types/set.py | 7 ++----- requirements-ext.txt | 1 + setup.py | 1 + 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/pyop2/types/set.py b/pyop2/types/set.py index 9a831aed6..25abdf93c 100644 --- a/pyop2/types/set.py +++ b/pyop2/types/set.py @@ -1,8 +1,8 @@ import ctypes -import functools import numbers import numpy as np +import pytools from pyop2 import ( caching, @@ -538,10 +538,7 @@ def __init__(self, sets): "All components of a MixedSet must have the same number of layers." # TODO: do all sets need the same communicator? self.comm = mpi.internal_comm( - functools.reduce( - lambda a, b: a or b, - map(lambda s: s if s is None else s.comm, sets) - ), + pytools.single_valued(s.comm for s in sets if s is not None), self ) self._initialized = True diff --git a/requirements-ext.txt b/requirements-ext.txt index 0f19e0d06..2ccb04374 100644 --- a/requirements-ext.txt +++ b/requirements-ext.txt @@ -8,3 +8,4 @@ decorator<=4.4.2 dataclasses cachetools packaging +pytools diff --git a/setup.py b/setup.py index ad9f7815b..06c03e152 100644 --- a/setup.py +++ b/setup.py @@ -89,6 +89,7 @@ def get_petsc_dir(): 'decorator', 'mpi4py', 'numpy>=1.6', + 'pytools', ] version = sys.version_info[:2] From 1454e81dd0bad95d28fcc9cdb5ea61988af7e50f Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Tue, 12 Dec 2023 15:03:01 +0000 Subject: [PATCH 6/6] Add docstring for debug printing --- pyop2/mpi.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/pyop2/mpi.py b/pyop2/mpi.py index d6e8ed713..04652ee03 100644 --- a/pyop2/mpi.py +++ b/pyop2/mpi.py @@ -479,6 +479,15 @@ def compilation_comm(comm, obj): def finalize_safe_debug(): + ''' Return function for debug output. + + When Python is finalizing the logging module may be finalized before we have + finished writing debug information. In this case we fall back to using the + Python `print` function to output debugging information. + + Furthermore, we always want to see this finalization information when + running the CI tests. + ''' if PYOP2_FINALIZED: if logger.level > DEBUG and not _running_on_ci: debug = lambda string: None