From 1ea3ffa7e79e9954480596291f6431470216d11a Mon Sep 17 00:00:00 2001 From: Angus Gibson Date: Fri, 13 Dec 2024 11:36:33 +1100 Subject: [PATCH] Attempt to unify MixedElement with the new UFL interface Unfortunately this isn't quite there, as MixedElement has to deviate from FiniteElementBase in a few places: If a component is specified in either is_cellwise_constant() or degree(), a domain must also be provided so the component can be extracted using extract_component(). I've put this in as an optional keyword-only argument, but I don't think this is the best approach. --- finat/ufl/finiteelementbase.py | 2 +- finat/ufl/mixedelement.py | 30 +++++++++++++++--------------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/finat/ufl/finiteelementbase.py b/finat/ufl/finiteelementbase.py index 3a84f908..8057ab7d 100644 --- a/finat/ufl/finiteelementbase.py +++ b/finat/ufl/finiteelementbase.py @@ -143,7 +143,7 @@ def symmetry(self): # FIXME: different approach def _check_component(self, domain, i): """Check that component index i is valid.""" - sh = self.value_shape(domain.geometric_dimension()) + sh = self.reference_value_shape r = len(sh) if not (len(i) == r and all(j < k for (j, k) in zip(i, sh))): raise ValueError( diff --git a/finat/ufl/mixedelement.py b/finat/ufl/mixedelement.py index e1dd1ccd..520aa813 100644 --- a/finat/ufl/mixedelement.py +++ b/finat/ufl/mixedelement.py @@ -87,7 +87,7 @@ def reconstruct_from_elements(self, *elements): return self return MixedElement(*elements) - def symmetry(self, domain): + def symmetry(self): r"""Return the symmetry dict, which is a mapping :math:`c_0 \\to c_1`. meaning that component :math:`c_0` is represented by component @@ -99,7 +99,7 @@ def symmetry(self, domain): # Base index of the current subelement into mixed value j = 0 for e in self._sub_elements: - sh = e.value_shape(domain) + sh = e.reference_value_shape st = shape_to_strides(sh) # Map symmetries of subelement into index space of this # element @@ -109,7 +109,7 @@ def symmetry(self, domain): sm[(j0,)] = (j1,) # Update base index for next element j += product(sh) - if j != product(self.value_shape(domain)): + if j != product(self.reference_value_shape): raise ValueError("Size mismatch in symmetry algorithm.") return sm or {} @@ -142,17 +142,17 @@ def extract_subelement_component(self, domain, i): """ if isinstance(i, int): i = (i,) - self._check_component(i) + self._check_component(domain, i) # Select between indexing modes - if len(self.value_shape(domain)) == 1: + if len(self.reference_value_shape) == 1: # Indexing into a long vector of flattened subelement # shapes j, = i # Find subelement for this index for sub_element_index, e in enumerate(self._sub_elements): - sh = e.value_shape(domain) + sh = e.reference_value_shape si = product(sh) if j < si: break @@ -172,13 +172,13 @@ def extract_subelement_component(self, domain, i): component = i[1:] return (sub_element_index, component) - def extract_component(self, i): + def extract_component(self, domain, i): """Recursively extract component index relative to a (simple) element. and that element for given value component index. """ - sub_element_index, component = self.extract_subelement_component(i) - return self._sub_elements[sub_element_index].extract_component(component) + sub_element_index, component = self.extract_subelement_component(domain, i) + return self._sub_elements[sub_element_index].extract_component(domain, component) def extract_subelement_reference_component(self, i): """Extract direct subelement index and subelement relative. @@ -217,20 +217,20 @@ def extract_reference_component(self, i): sub_element_index, reference_component = self.extract_subelement_reference_component(i) return self._sub_elements[sub_element_index].extract_reference_component(reference_component) - def is_cellwise_constant(self, component=None): + def is_cellwise_constant(self, component=None, *, domain=None): """Return whether the basis functions of this element is spatially constant over each cell.""" if component is None: return all(e.is_cellwise_constant() for e in self.sub_elements) else: - i, e = self.extract_component(component) + i, e = self.extract_component(domain, component) return e.is_cellwise_constant() - def degree(self, component=None): + def degree(self, component=None, *, domain=None): """Return polynomial degree of finite element.""" if component is None: return self._degree # from FiniteElementBase, computed as max of subelements in __init__ else: - i, e = self.extract_component(component) + i, e = self.extract_component(domain, component) return e.degree() @property @@ -490,14 +490,14 @@ def flattened_sub_element_mapping(self): """Doc.""" return self._flattened_sub_element_mapping - def extract_subelement_component(self, i): + def extract_subelement_component(self, domain, i): """Extract direct subelement index and subelement relative. component index for a given component index. """ if isinstance(i, int): i = (i,) - self._check_component(i) + self._check_component(domain, i) i = self.symmetry().get(i, i) l = len(self._shape) # noqa: E741