Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify MixedElement with the new UFL interface #119

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion finat/ufl/finiteelementbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def symmetry(self): # FIXME: different approach

def _check_component(self, domain, i):
"""Check that component index i is valid."""
sh = self.value_shape(domain.geometric_dimension())
Comment on lines 144 to -146

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this function still depend on the domain, or could it just be renamed to _check_reference_component?

sh = self.reference_value_shape
r = len(sh)
if not (len(i) == r and all(j < k for (j, k) in zip(i, sh))):
raise ValueError(
Expand Down
30 changes: 15 additions & 15 deletions finat/ufl/mixedelement.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def reconstruct_from_elements(self, *elements):
return self
return MixedElement(*elements)

def symmetry(self, domain):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with this change. The symmetry should only be imposed on the reference components

def symmetry(self):
r"""Return the symmetry dict, which is a mapping :math:`c_0 \\to c_1`.

meaning that component :math:`c_0` is represented by component
Expand All @@ -99,7 +99,7 @@ def symmetry(self, domain):
# Base index of the current subelement into mixed value
j = 0
for e in self._sub_elements:
sh = e.value_shape(domain)
sh = e.reference_value_shape
st = shape_to_strides(sh)
# Map symmetries of subelement into index space of this
# element
Expand All @@ -109,7 +109,7 @@ def symmetry(self, domain):
sm[(j0,)] = (j1,)
# Update base index for next element
j += product(sh)
if j != product(self.value_shape(domain)):
if j != product(self.reference_value_shape):
raise ValueError("Size mismatch in symmetry algorithm.")
return sm or {}

Expand Down Expand Up @@ -142,17 +142,17 @@ def extract_subelement_component(self, domain, i):
"""
if isinstance(i, int):
i = (i,)
self._check_component(i)
self._check_component(domain, i)

# Select between indexing modes
if len(self.value_shape(domain)) == 1:
if len(self.reference_value_shape) == 1:
# Indexing into a long vector of flattened subelement
# shapes
j, = i

# Find subelement for this index
for sub_element_index, e in enumerate(self._sub_elements):
sh = e.value_shape(domain)
sh = e.reference_value_shape
si = product(sh)
if j < si:
break
Expand All @@ -172,13 +172,13 @@ def extract_subelement_component(self, domain, i):
component = i[1:]
return (sub_element_index, component)

def extract_component(self, i):
def extract_component(self, domain, i):
"""Recursively extract component index relative to a (simple) element.

and that element for given value component index.
"""
sub_element_index, component = self.extract_subelement_component(i)
return self._sub_elements[sub_element_index].extract_component(component)
sub_element_index, component = self.extract_subelement_component(domain, i)
return self._sub_elements[sub_element_index].extract_component(domain, component)

def extract_subelement_reference_component(self, i):
"""Extract direct subelement index and subelement relative.
Expand Down Expand Up @@ -217,20 +217,20 @@ def extract_reference_component(self, i):
sub_element_index, reference_component = self.extract_subelement_reference_component(i)
return self._sub_elements[sub_element_index].extract_reference_component(reference_component)

def is_cellwise_constant(self, component=None):
def is_cellwise_constant(self, component=None, *, domain=None):
"""Return whether the basis functions of this element is spatially constant over each cell."""
if component is None:
return all(e.is_cellwise_constant() for e in self.sub_elements)
else:
i, e = self.extract_component(component)
i, e = self.extract_component(domain, component)
return e.is_cellwise_constant()

def degree(self, component=None):
def degree(self, component=None, *, domain=None):
"""Return polynomial degree of finite element."""
if component is None:
return self._degree # from FiniteElementBase, computed as max of subelements in __init__
else:
i, e = self.extract_component(component)
i, e = self.extract_component(domain, component)
return e.degree()

@property
Expand Down Expand Up @@ -490,14 +490,14 @@ def flattened_sub_element_mapping(self):
"""Doc."""
return self._flattened_sub_element_mapping

def extract_subelement_component(self, i):
def extract_subelement_component(self, domain, i):
"""Extract direct subelement index and subelement relative.

component index for a given component index.
"""
if isinstance(i, int):
i = (i,)
self._check_component(i)
self._check_component(domain, i)

i = self.symmetry().get(i, i)
l = len(self._shape) # noqa: E741
Expand Down
Loading