Skip to content
This repository has been archived by the owner on Nov 27, 2024. It is now read-only.

Commit

Permalink
Merge pull request #711 from OP2/JDBetteridge/RAII
Browse files Browse the repository at this point in the history
Remove __del__ method and add weakref.finalizer
  • Loading branch information
JDBetteridge authored Dec 12, 2023
2 parents 6d9fb0d + 1454e81 commit bef22bf
Show file tree
Hide file tree
Showing 13 changed files with 63 additions and 105 deletions.
7 changes: 0 additions & 7 deletions pyop2/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,6 @@ class ObjectCached(object):
details). The object on which the cache is stored should contain
a dict in its ``_cache`` attribute.
.. warning ::
This kind of cache sets up a circular reference. If either of
the objects implements ``__del__``, the Python garbage
collector will not be able to collect this cycle, and hence
the cache will never be evicted.
.. warning::
The derived class' :meth:`__init__` is still called if the
Expand Down
10 changes: 2 additions & 8 deletions pyop2/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,14 +188,8 @@ def __init__(self, extra_compiler_flags=(), extra_linker_flags=(), cpp=False, co
self._debug = configuration["debug"]

# Compilation communicators are reference counted on the PyOP2 comm
self.pcomm = mpi.internal_comm(comm)
self.comm = mpi.compilation_comm(self.pcomm)

def __del__(self):
if hasattr(self, "comm"):
mpi.decref(self.comm)
if hasattr(self, "pcomm"):
mpi.decref(self.pcomm)
self.pcomm = mpi.internal_comm(comm, self)
self.comm = mpi.compilation_comm(self.pcomm, self)

def __repr__(self):
return f"<{self._name} compiler, version {self.version or 'unknown'}>"
Expand Down
26 changes: 20 additions & 6 deletions pyop2/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import glob
import os
import tempfile
import weakref

from pyop2.configuration import configuration
from pyop2.exceptions import CompilationError
Expand Down Expand Up @@ -266,10 +267,7 @@ class temp_internal_comm:
"""
def __init__(self, comm):
self.user_comm = comm
self.internal_comm = internal_comm(self.user_comm)

def __del__(self):
decref(self.internal_comm)
self.internal_comm = internal_comm(self.user_comm, self)

def __enter__(self):
""" Returns an internal comm that will be safely decref'd
Expand All @@ -283,10 +281,12 @@ def __exit__(self, exc_type, exc_value, traceback):
pass


def internal_comm(comm):
def internal_comm(comm, obj):
""" Creates an internal comm from the user comm.
If comm is None, create an internal communicator from COMM_WORLD
:arg comm: A communicator or None
:arg obj: The object which the comm is an attribute of
(usually `self`)
:returns pyop2_comm: A PyOP2 internal communicator
"""
Expand All @@ -309,6 +309,7 @@ def internal_comm(comm):
pyop2_comm = comm
else:
pyop2_comm = dup_comm(comm)
weakref.finalize(obj, decref, pyop2_comm)
return pyop2_comm


Expand Down Expand Up @@ -445,10 +446,13 @@ def set_compilation_comm(comm, comp_comm):


@collective
def compilation_comm(comm):
def compilation_comm(comm, obj):
"""Get a communicator for compilation.
:arg comm: The input communicator, must be a PyOP2 comm.
:arg obj: The object which the comm is an attribute of
(usually `self`)
:returns: A communicator used for compilation (may be smaller)
"""
if not is_pyop2_comm(comm):
Expand All @@ -470,10 +474,20 @@ def compilation_comm(comm):
else:
comp_comm = comm
incref(comp_comm)
weakref.finalize(obj, decref, comp_comm)
return comp_comm


def finalize_safe_debug():
''' Return function for debug output.
When Python is finalizing the logging module may be finalized before we have
finished writing debug information. In this case we fall back to using the
Python `print` function to output debugging information.
Furthermore, we always want to see this finalization information when
running the CI tests.
'''
if PYOP2_FINALIZED:
if logger.level > DEBUG and not _running_on_ci:
debug = lambda string: None
Expand Down
6 changes: 1 addition & 5 deletions pyop2/parloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,13 +151,9 @@ def __init__(self, global_knl, iterset, arguments):

self.global_kernel = global_knl
self.iterset = iterset
self.comm = mpi.internal_comm(iterset.comm)
self.comm = mpi.internal_comm(iterset.comm, self)
self.arguments, self.reduced_globals = self.prepare_reduced_globals(arguments, global_knl)

def __del__(self):
if hasattr(self, "comm"):
mpi.decref(self.comm)

@property
def local_kernel(self):
return self.global_kernel.local_kernel
Expand Down
8 changes: 2 additions & 6 deletions pyop2/types/dat.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,17 +82,13 @@ def __init__(self, dataset, data=None, dtype=None, name=None):
EmptyDataMixin.__init__(self, data, dtype, self._shape)

self._dataset = dataset
self.comm = mpi.internal_comm(dataset.comm)
self.comm = mpi.internal_comm(dataset.comm, self)
self.halo_valid = True
self._name = name or "dat_#x%x" % id(self)

self._halo_frozen = False
self._frozen_access_mode = None

def __del__(self):
if hasattr(self, "comm"):
mpi.decref(self.comm)

@utils.cached_property
def _kernel_args_(self):
return (self._data.ctypes.data, )
Expand Down Expand Up @@ -823,7 +819,7 @@ def what(x):
if not all(d.dtype == self._dats[0].dtype for d in self._dats):
raise ex.DataValueError('MixedDat with different dtypes is not supported')
# TODO: Think about different communicators on dats (c.f. MixedSet)
self.comm = mpi.internal_comm(self._dats[0].comm)
self.comm = mpi.internal_comm(self._dats[0].comm, self)

@property
def dat_version(self):
Expand Down
12 changes: 3 additions & 9 deletions pyop2/types/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,13 @@ def __init__(self, iter_set, dim=1, name=None):
return
if isinstance(iter_set, Subset):
raise NotImplementedError("Deriving a DataSet from a Subset is unsupported")
self.comm = mpi.internal_comm(iter_set.comm)
self.comm = mpi.internal_comm(iter_set.comm, self)
self._set = iter_set
self._dim = utils.as_tuple(dim, numbers.Integral)
self._cdim = np.prod(self._dim).item()
self._name = name or "dset_#x%x" % id(self)
self._initialized = True

def __del__(self):
# Cannot use hasattr here, since we define `__getattr__`
# This causes infinite recursion when looked up!
if "comm" in self.__dict__:
mpi.decref(self.comm)

@classmethod
def _process_args(cls, *args, **kwargs):
return (args[0], ) + args, kwargs
Expand Down Expand Up @@ -211,7 +205,7 @@ def __init__(self, global_):
if self._initialized:
return
self._global = global_
self.comm = mpi.internal_comm(global_.comm)
self.comm = mpi.internal_comm(global_.comm, self)
self._globalset = GlobalSet(comm=self.comm)
self._name = "gdset_#x%x" % id(self)
self._initialized = True
Expand Down Expand Up @@ -381,7 +375,7 @@ def __init__(self, arg, dims=None):
comm = self._process_args(arg, dims)[0][0].comm
except AttributeError:
comm = None
self.comm = mpi.internal_comm(comm)
self.comm = mpi.internal_comm(comm, self)
self._initialized = True

@classmethod
Expand Down
10 changes: 1 addition & 9 deletions pyop2/types/glob.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,6 @@ def __init__(self, dim, data=None, dtype=None, name=None):
self._buf = np.empty(self.shape, dtype=self.dtype)
self._name = name or "%s_#x%x" % (self.__class__.__name__.lower(), id(self))

def __del__(self):
if hasattr(self, "comm"):
mpi.decref(self.comm)

@utils.cached_property
def _kernel_args_(self):
return (self._data.ctypes.data, )
Expand Down Expand Up @@ -247,16 +243,12 @@ def __init__(self, dim, data=None, dtype=None, name=None, comm=None):
super().__init__(dim, data, dtype, name)
if comm is None:
warnings.warn("PyOP2.Global has no comm, this is likely to break in parallel!")
self.comm = mpi.internal_comm(comm)
self.comm = mpi.internal_comm(comm, self)

# Object versioning setup
petsc_counter = (comm and self.dtype == PETSc.ScalarType)
VecAccessMixin.__init__(self, petsc_counter=petsc_counter)

def __del__(self):
if hasattr(self, "comm"):
mpi.decref(self.comm)

def __str__(self):
return "OP2 Global Argument: %s with dim %s and value %s" \
% (self._name, self._dim, self._data)
Expand Down
12 changes: 4 additions & 8 deletions pyop2/types/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class Map:
def __init__(self, iterset, toset, arity, values=None, name=None, offset=None, offset_quotient=None):
self._iterset = iterset
self._toset = toset
self.comm = mpi.internal_comm(toset.comm)
self.comm = mpi.internal_comm(toset.comm, self)
self._arity = arity
self._values = utils.verify_reshape(values, dtypes.IntType,
(iterset.total_size, arity), allow_none=True)
Expand All @@ -53,10 +53,6 @@ def __init__(self, iterset, toset, arity, values=None, name=None, offset=None, o
# A cache for objects built on top of this map
self._cache = {}

def __del__(self):
if hasattr(self, "comm"):
mpi.decref(self.comm)

@utils.cached_property
def _kernel_args_(self):
return (self._values.ctypes.data, )
Expand Down Expand Up @@ -200,7 +196,7 @@ def __init__(self, map_, permutation):
if isinstance(map_, ComposedMap):
raise NotImplementedError("PermutedMap of ComposedMap not implemented: simply permute before composing")
self.map_ = map_
self.comm = mpi.internal_comm(map_.comm)
self.comm = mpi.internal_comm(map_.comm, self)
self.permutation = np.asarray(permutation, dtype=Map.dtype)
assert (np.unique(permutation) == np.arange(map_.arity, dtype=Map.dtype)).all()

Expand Down Expand Up @@ -251,7 +247,7 @@ def __init__(self, *maps_, name=None):
raise ex.MapTypeError("frommap.arity must be 1")
self._iterset = maps_[-1].iterset
self._toset = maps_[0].toset
self.comm = mpi.internal_comm(self._toset.comm)
self.comm = mpi.internal_comm(self._toset.comm, self)
self._arity = maps_[0].arity
# Don't call super().__init__() to avoid calling verify_reshape()
self._values = None
Expand Down Expand Up @@ -315,7 +311,7 @@ def __init__(self, maps):
raise ex.MapTypeError("All maps needs to share a communicator")
if len(comms) == 0:
raise ex.MapTypeError("Don't know how to make communicator")
self.comm = mpi.internal_comm(comms[0])
self.comm = mpi.internal_comm(comms[0], self)
self._initialized = True

@classmethod
Expand Down
46 changes: 18 additions & 28 deletions pyop2/types/mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,17 @@ def __init__(self, dsets, maps, *, iteration_regions=None, name=None, nest=None,
self._o_nnz = None
self._nrows = None if isinstance(dsets[0], GlobalDataSet) else self._rmaps[0].toset.size
self._ncols = None if isinstance(dsets[1], GlobalDataSet) else self._cmaps[0].toset.size
self.lcomm = mpi.internal_comm(dsets[0].comm if isinstance(dsets[0], GlobalDataSet) else self._rmaps[0].comm)
self.rcomm = mpi.internal_comm(dsets[1].comm if isinstance(dsets[1], GlobalDataSet) else self._cmaps[0].comm)
self.lcomm = mpi.internal_comm(
dsets[0].comm if isinstance(dsets[0], GlobalDataSet) else self._rmaps[0].comm,
self
)
self.rcomm = mpi.internal_comm(
dsets[1].comm if isinstance(dsets[1], GlobalDataSet) else self._cmaps[0].comm,
self
)
else:
self.lcomm = mpi.internal_comm(self._rmaps[0].comm)
self.rcomm = mpi.internal_comm(self._cmaps[0].comm)
self.lcomm = mpi.internal_comm(self._rmaps[0].comm, self)
self.rcomm = mpi.internal_comm(self._cmaps[0].comm, self)

rset, cset = self.dsets
# All rmaps and cmaps have the same data set - just use the first.
Expand All @@ -93,7 +99,7 @@ def __init__(self, dsets, maps, *, iteration_regions=None, name=None, nest=None,

if self.lcomm != self.rcomm:
raise ValueError("Haven't thought hard enough about different left and right communicators")
self.comm = mpi.internal_comm(self.lcomm)
self.comm = mpi.internal_comm(self.lcomm, self)
self._name = name or "sparsity_#x%x" % id(self)
self.iteration_regions = iteration_regions
# If the Sparsity is defined on MixedDataSets, we need to build each
Expand Down Expand Up @@ -129,14 +135,6 @@ def __init__(self, dsets, maps, *, iteration_regions=None, name=None, nest=None,
self._blocks = [[self]]
self._initialized = True

def __del__(self):
if hasattr(self, "comm"):
mpi.decref(self.comm)
if hasattr(self, "lcomm"):
mpi.decref(self.lcomm)
if hasattr(self, "rcomm"):
mpi.decref(self.rcomm)

_cache = {}

@classmethod
Expand Down Expand Up @@ -383,10 +381,10 @@ def __init__(self, parent, i, j):
self._dims = tuple([tuple([parent.dims[i][j]])])
self._blocks = [[self]]
self.iteration_regions = parent.iteration_regions
self.lcomm = mpi.internal_comm(self.dsets[0].comm)
self.rcomm = mpi.internal_comm(self.dsets[1].comm)
self.lcomm = mpi.internal_comm(self.dsets[0].comm, self)
self.rcomm = mpi.internal_comm(self.dsets[1].comm, self)
# TODO: think about lcomm != rcomm
self.comm = mpi.internal_comm(self.lcomm)
self.comm = mpi.internal_comm(self.lcomm, self)
self._initialized = True

@classmethod
Expand Down Expand Up @@ -445,22 +443,14 @@ class AbstractMat(DataCarrier, abc.ABC):
('name', str, ex.NameTypeError))
def __init__(self, sparsity, dtype=None, name=None):
self._sparsity = sparsity
self.lcomm = mpi.internal_comm(sparsity.lcomm)
self.rcomm = mpi.internal_comm(sparsity.rcomm)
self.comm = mpi.internal_comm(sparsity.comm)
self.lcomm = mpi.internal_comm(sparsity.lcomm, self)
self.rcomm = mpi.internal_comm(sparsity.rcomm, self)
self.comm = mpi.internal_comm(sparsity.comm, self)
dtype = dtype or dtypes.ScalarType
self._datatype = np.dtype(dtype)
self._name = name or "mat_#x%x" % id(self)
self.assembly_state = Mat.ASSEMBLED

def __del__(self):
if hasattr(self, "comm"):
mpi.decref(self.comm)
if hasattr(self, "lcomm"):
mpi.decref(self.lcomm)
if hasattr(self, "rcomm"):
mpi.decref(self.rcomm)

@utils.validate_in(('access', _modes, ex.ModeValueError))
def __call__(self, access, path, lgmaps=None, unroll_map=False):
from pyop2.parloop import MatLegacyArg, MixedMatLegacyArg
Expand Down Expand Up @@ -958,7 +948,7 @@ def __init__(self, parent, i, j):
colis = cset.local_ises[j]
self.handle = parent.handle.getLocalSubMatrix(isrow=rowis,
iscol=colis)
self.comm = mpi.internal_comm(parent.comm)
self.comm = mpi.internal_comm(parent.comm, self)
self.local_to_global_maps = self.handle.getLGMap()

@property
Expand Down
Loading

0 comments on commit bef22bf

Please sign in to comment.