diff --git a/pyop2/parloop.py b/pyop2/parloop.py index 384576fa8..c70f4c9fb 100644 --- a/pyop2/parloop.py +++ b/pyop2/parloop.py @@ -121,7 +121,7 @@ def _kernel_args_(self): @property def map_kernel_args(self): rmap, cmap = self.maps - return tuple(itertools.chain(*itertools.product(rmap._kernel_args_, cmap._kernel_args_))) + return tuple(itertools.chain(rmap._kernel_args_, cmap._kernel_args_)) @dataclass @@ -143,7 +143,7 @@ def _kernel_args_(self): @property def map_kernel_args(self): rmap, cmap = self.maps - return tuple(itertools.chain(*itertools.product(rmap._kernel_args_, cmap._kernel_args_))) + return tuple(itertools.chain(rmap._kernel_args_, cmap._kernel_args_)) @dataclass diff --git a/pyop2/sparsity.pyx b/pyop2/sparsity.pyx index 131e91888..d6411feca 100644 --- a/pyop2/sparsity.pyx +++ b/pyop2/sparsity.pyx @@ -199,7 +199,7 @@ def fill_with_zeros(PETSc.Mat mat not None, dims, maps, iteration_regions, set_d PetscInt[:, ::1] rmap, cmap, tempmap PetscInt **rcomposedmaps = NULL PetscInt **ccomposedmaps = NULL - PetscInt nrcomposedmaps = 0, nccomposedmaps = 0, rset_entry, cset_entry + PetscInt nrcomposedmaps, nccomposedmaps, rset_entry, cset_entry PetscInt *rvals PetscInt *cvals PetscInt *roffset @@ -235,6 +235,7 @@ def fill_with_zeros(PETSc.Mat mat not None, dims, maps, iteration_regions, set_d else: rflags.append(set_writeable(pair[0])) # Memoryviews require writeable buffers rmap = pair[0].values_with_halo # Map values + nrcomposedmaps = 0 if isinstance(pair[1], op2.ComposedMap): m = pair[1].flattened_maps[0] cflags.append(set_writeable(m)) @@ -243,6 +244,7 @@ def fill_with_zeros(PETSc.Mat mat not None, dims, maps, iteration_regions, set_d else: cflags.append(set_writeable(pair[1])) cmap = pair[1].values_with_halo + nccomposedmaps = 0 # Handle ComposedMaps CHKERR(PetscMalloc2(nrcomposedmaps, &rcomposedmaps, nccomposedmaps, &ccomposedmaps)) for i in range(nrcomposedmaps): diff --git a/pyop2/types/map.py b/pyop2/types/map.py index 9d9ca48ae..81e386546 100644 --- a/pyop2/types/map.py +++ b/pyop2/types/map.py @@ -302,8 +302,6 @@ def __init__(self, maps): if self._initialized: return self._maps = maps - if not all(m is None or m.iterset == self.iterset for m in self._maps): - raise ex.MapTypeError("All maps in a MixedMap need to share the same iterset") # TODO: Think about different communicators on maps (c.f. MixedSet) # TODO: What if all maps are None? comms = tuple(m.comm for m in self._maps if m is not None) @@ -344,7 +342,11 @@ def split(self): @utils.cached_property def iterset(self): """:class:`MixedSet` mapped from.""" - return functools.reduce(lambda a, b: a or b, map(lambda s: s if s is None else s.iterset, self._maps)) + s, = set(m.iterset for m in self._maps) + if len(s) == 1: + return functools.reduce(lambda a, b: a or b, map(lambda s: s if s is None else s.iterset, self._maps)) + else: + raise RuntimeError("Found multiple itersets.") @utils.cached_property def toset(self): @@ -356,7 +358,11 @@ def toset(self): def arity(self): """Arity of the mapping: total number of toset elements mapped to per iterset element.""" - return sum(m.arity for m in self._maps) + s, = set(m.iterset for m in self._maps) + if len(s) == 1: + return sum(m.arity for m in self._maps) + else: + raise RuntimeError("Found multiple itersets.") @utils.cached_property def arities(self): @@ -402,7 +408,7 @@ def offset(self): @utils.cached_property def offset_quotient(self): """Offsets quotient.""" - raise NotImplementedError("offset_quotient not implemented for MixedMap") + return tuple(0 if m is None else m.offset_quotient for m in self._maps) def __iter__(self): r"""Yield all :class:`Map`\s when iterated over.""" diff --git a/pyop2/types/set.py b/pyop2/types/set.py index 25abdf93c..32fb01844 100644 --- a/pyop2/types/set.py +++ b/pyop2/types/set.py @@ -78,6 +78,11 @@ def __init__(self, size, name=None, halo=None, comm=None): # A cache of objects built on top of this set self._cache = {} + @property + def indices(self): + """Returns iterator.""" + return range(self.total_size) + @utils.cached_property def core_size(self): """Core set size. Owned elements not touching halo elements.""" diff --git a/test/unit/test_api.py b/test/unit/test_api.py index 066d4aa9b..468d17558 100644 --- a/test/unit/test_api.py +++ b/test/unit/test_api.py @@ -1446,11 +1446,6 @@ def test_mixed_map_split(self, maps): assert mmap.split[i] == m assert mmap.split[:-1] == tuple(mmap)[:-1] - def test_mixed_map_nonunique_itset(self, m_iterset_toset, m_set_toset): - "Map toset should be Set." - with pytest.raises(exceptions.MapTypeError): - op2.MixedMap((m_iterset_toset, m_set_toset)) - def test_mixed_map_iterset(self, mmap): "MixedMap iterset should return the common iterset of all Maps." for m in mmap: