Skip to content
This repository has been archived by the owner on Nov 27, 2024. It is now read-only.

RestrictedFunctionSpace: Added in changes to Dataset / Set for use in RestrictedFunctionSpace in firedrake #716

Merged
merged 11 commits into from
Apr 26, 2024
12 changes: 6 additions & 6 deletions pyop2/types/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def lgmap(self):
indices for this :class:`DataSet`.
"""
lgmap = PETSc.LGMap()
if self.comm.size == 1:
if self.comm.size == 1 and self.halo is None:
lgmap.create(indices=np.arange(self.size, dtype=dtypes.IntType),
bsize=self.cdim, comm=self.comm)
else:
Expand Down Expand Up @@ -183,7 +183,7 @@ def local_ises(self):
def layout_vec(self):
"""A PETSc Vec compatible with the dof layout of this DataSet."""
vec = PETSc.Vec().create(comm=self.comm)
size = (self.size * self.cdim, None)
size = ((self.size - self.set.constrained_size) * self.cdim, None)
vec.setSizes(size, bsize=self.cdim)
vec.setUp()
return vec
Expand Down Expand Up @@ -449,8 +449,8 @@ def lgmap(self):
indices for this :class:`MixedDataSet`.
"""
lgmap = PETSc.LGMap()
if self.comm.size == 1:
size = sum(s.size * s.cdim for s in self)
if self.comm.size == 1 and self.halo is None:
size = sum((s.size - s.constrained_size) * s.cdim for s in self)
lgmap.create(indices=np.arange(size, dtype=dtypes.IntType),
bsize=1, comm=self.comm)
return lgmap
Expand Down Expand Up @@ -479,7 +479,7 @@ def lgmap(self):
# current field offset.
idx_size = sum(s.total_size*s.cdim for s in self)
indices = np.full(idx_size, -1, dtype=dtypes.IntType)
owned_sz = np.array([sum(s.size * s.cdim for s in self)],
owned_sz = np.array([sum((s.size - s.constrained_size) * s.cdim for s in self)],
dtype=dtypes.IntType)
field_offset = np.empty_like(owned_sz)
self.comm.Scan(owned_sz, field_offset)
Expand All @@ -493,7 +493,7 @@ def lgmap(self):
current_offsets = np.zeros(self.comm.size + 1, dtype=dtypes.IntType)
for s in self:
idx = indices[start:start + s.total_size * s.cdim]
owned_sz[0] = s.size * s.cdim
owned_sz[0] = (s.size - s.set.constrained_size) * s.cdim
self.comm.Scan(owned_sz, field_offset)
self.comm.Allgather(field_offset, current_offsets[1:])
# Find the ranks each entry in the l2g belongs to
Expand Down
13 changes: 12 additions & 1 deletion pyop2/types/set.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def _wrapper_cache_key_(self):

@utils.validate_type(('size', (numbers.Integral, tuple, list, np.ndarray), ex.SizeTypeError),
('name', str, ex.NameTypeError))
def __init__(self, size, name=None, halo=None, comm=None):
def __init__(self, size, name=None, halo=None, comm=None, constrained_size=0):
self.comm = mpi.internal_comm(comm, self)
if isinstance(size, numbers.Integral):
size = [size] * 3
Expand All @@ -75,6 +75,8 @@ def __init__(self, size, name=None, halo=None, comm=None):
self._name = name or "set_#x%x" % id(self)
self._halo = halo
self._partition_size = 1024
self._constrained_size = constrained_size

# A cache of objects built on top of this set
self._cache = {}

Expand All @@ -88,6 +90,10 @@ def core_size(self):
"""Core set size. Owned elements not touching halo elements."""
return self._sizes[Set._CORE_SIZE]

@utils.cached_property
def constrained_size(self):
return self._constrained_size

@utils.cached_property
def size(self):
"""Set size, owned elements."""
Expand Down Expand Up @@ -588,6 +594,11 @@ def core_size(self):
"""Core set size. Owned elements not touching halo elements."""
return sum(s.core_size for s in self._sets)

@utils.cached_property
def constrained_size(self):
"""Set size, owned constrained elements."""
return sum(s.constrained_size for s in self._sets)

@utils.cached_property
def size(self):
"""Set size, owned elements."""
Expand Down
Loading