diff --git a/pyop2/types/map.py b/pyop2/types/map.py index 9d9ca48ae..81e386546 100644 --- a/pyop2/types/map.py +++ b/pyop2/types/map.py @@ -302,8 +302,6 @@ def __init__(self, maps): if self._initialized: return self._maps = maps - if not all(m is None or m.iterset == self.iterset for m in self._maps): - raise ex.MapTypeError("All maps in a MixedMap need to share the same iterset") # TODO: Think about different communicators on maps (c.f. MixedSet) # TODO: What if all maps are None? comms = tuple(m.comm for m in self._maps if m is not None) @@ -344,7 +342,11 @@ def split(self): @utils.cached_property def iterset(self): """:class:`MixedSet` mapped from.""" - return functools.reduce(lambda a, b: a or b, map(lambda s: s if s is None else s.iterset, self._maps)) + s, = set(m.iterset for m in self._maps) + if len(s) == 1: + return functools.reduce(lambda a, b: a or b, map(lambda s: s if s is None else s.iterset, self._maps)) + else: + raise RuntimeError("Found multiple itersets.") @utils.cached_property def toset(self): @@ -356,7 +358,11 @@ def toset(self): def arity(self): """Arity of the mapping: total number of toset elements mapped to per iterset element.""" - return sum(m.arity for m in self._maps) + s, = set(m.iterset for m in self._maps) + if len(s) == 1: + return sum(m.arity for m in self._maps) + else: + raise RuntimeError("Found multiple itersets.") @utils.cached_property def arities(self): @@ -402,7 +408,7 @@ def offset(self): @utils.cached_property def offset_quotient(self): """Offsets quotient.""" - raise NotImplementedError("offset_quotient not implemented for MixedMap") + return tuple(0 if m is None else m.offset_quotient for m in self._maps) def __iter__(self): r"""Yield all :class:`Map`\s when iterated over.""" diff --git a/pyop2/types/set.py b/pyop2/types/set.py index 25abdf93c..62123f482 100644 --- a/pyop2/types/set.py +++ b/pyop2/types/set.py @@ -78,6 +78,10 @@ def __init__(self, size, name=None, halo=None, comm=None): # A cache of objects built on top of this set self._cache = {} + def indices(self): + """Returns iterator.""" + return range(self.total_size) + @utils.cached_property def core_size(self): """Core set size. Owned elements not touching halo elements.""" diff --git a/test/unit/test_api.py b/test/unit/test_api.py index 066d4aa9b..468d17558 100644 --- a/test/unit/test_api.py +++ b/test/unit/test_api.py @@ -1446,11 +1446,6 @@ def test_mixed_map_split(self, maps): assert mmap.split[i] == m assert mmap.split[:-1] == tuple(mmap)[:-1] - def test_mixed_map_nonunique_itset(self, m_iterset_toset, m_set_toset): - "Map toset should be Set." - with pytest.raises(exceptions.MapTypeError): - op2.MixedMap((m_iterset_toset, m_set_toset)) - def test_mixed_map_iterset(self, mmap): "MixedMap iterset should return the common iterset of all Maps." for m in mmap: