diff --git a/finat/ufl/mixedelement.py b/finat/ufl/mixedelement.py index e1dd1ccd..8817fefc 100644 --- a/finat/ufl/mixedelement.py +++ b/finat/ufl/mixedelement.py @@ -87,7 +87,7 @@ def reconstruct_from_elements(self, *elements): return self return MixedElement(*elements) - def symmetry(self, domain): + def symmetry(self, domain=None): r"""Return the symmetry dict, which is a mapping :math:`c_0 \\to c_1`. meaning that component :math:`c_0` is represented by component @@ -103,7 +103,7 @@ def symmetry(self, domain): st = shape_to_strides(sh) # Map symmetries of subelement into index space of this # element - for c0, c1 in e.symmetry().items(): + for c0, c1 in e.symmetry(domain).items(): j0 = flatten_multiindex(c0, st) + j j1 = flatten_multiindex(c1, st) + j sm[(j0,)] = (j1,) @@ -142,7 +142,7 @@ def extract_subelement_component(self, domain, i): """ if isinstance(i, int): i = (i,) - self._check_component(i) + self._check_component(domain, i) # Select between indexing modes if len(self.value_shape(domain)) == 1: @@ -172,12 +172,12 @@ def extract_subelement_component(self, domain, i): component = i[1:] return (sub_element_index, component) - def extract_component(self, i): + def extract_component(self, domain, i): """Recursively extract component index relative to a (simple) element. and that element for given value component index. """ - sub_element_index, component = self.extract_subelement_component(i) + sub_element_index, component = self.extract_subelement_component(domain, i) return self._sub_elements[sub_element_index].extract_component(component) def extract_subelement_reference_component(self, i): @@ -217,20 +217,20 @@ def extract_reference_component(self, i): sub_element_index, reference_component = self.extract_subelement_reference_component(i) return self._sub_elements[sub_element_index].extract_reference_component(reference_component) - def is_cellwise_constant(self, component=None): + def is_cellwise_constant(self, component=None, domain=None): """Return whether the basis functions of this element is spatially constant over each cell.""" if component is None: return all(e.is_cellwise_constant() for e in self.sub_elements) else: - i, e = self.extract_component(component) + i, e = self.extract_component(domain, component) return e.is_cellwise_constant() - def degree(self, component=None): + def degree(self, component=None, domain=None): """Return polynomial degree of finite element.""" if component is None: return self._degree # from FiniteElementBase, computed as max of subelements in __init__ else: - i, e = self.extract_component(component) + i, e = self.extract_component(domain, component) return e.degree() @property @@ -490,16 +490,16 @@ def flattened_sub_element_mapping(self): """Doc.""" return self._flattened_sub_element_mapping - def extract_subelement_component(self, i): + def extract_subelement_component(self, domain, i): """Extract direct subelement index and subelement relative. component index for a given component index. """ if isinstance(i, int): i = (i,) - self._check_component(i) + self._check_component(domain, i) - i = self.symmetry().get(i, i) + i = self.symmetry(domain).get(i, i) l = len(self._shape) # noqa: E741 ii = i[:l] jj = i[l:] @@ -508,7 +508,7 @@ def extract_subelement_component(self, i): k = self._sub_element_mapping[ii] return (k, jj) - def symmetry(self): + def symmetry(self, domain=None): r"""Return the symmetry dict, which is a mapping :math:`c_0 \\to c_1`. meaning that component :math:`c_0` is represented by component