Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix extract_subelement_component #122

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions finat/ufl/mixedelement.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def reconstruct_from_elements(self, *elements):
return self
return MixedElement(*elements)

def symmetry(self, domain):
def symmetry(self, domain=None):
r"""Return the symmetry dict, which is a mapping :math:`c_0 \\to c_1`.

meaning that component :math:`c_0` is represented by component
Expand All @@ -103,7 +103,7 @@ def symmetry(self, domain):
st = shape_to_strides(sh)
# Map symmetries of subelement into index space of this
# element
for c0, c1 in e.symmetry().items():
for c0, c1 in e.symmetry(domain).items():
j0 = flatten_multiindex(c0, st) + j
j1 = flatten_multiindex(c1, st) + j
sm[(j0,)] = (j1,)
Expand Down Expand Up @@ -142,7 +142,7 @@ def extract_subelement_component(self, domain, i):
"""
if isinstance(i, int):
i = (i,)
self._check_component(i)
self._check_component(domain, i)

# Select between indexing modes
if len(self.value_shape(domain)) == 1:
Expand Down Expand Up @@ -172,12 +172,12 @@ def extract_subelement_component(self, domain, i):
component = i[1:]
return (sub_element_index, component)

def extract_component(self, i):
def extract_component(self, domain, i):
"""Recursively extract component index relative to a (simple) element.

and that element for given value component index.
"""
sub_element_index, component = self.extract_subelement_component(i)
sub_element_index, component = self.extract_subelement_component(domain, i)
return self._sub_elements[sub_element_index].extract_component(component)

def extract_subelement_reference_component(self, i):
Expand Down Expand Up @@ -217,20 +217,20 @@ def extract_reference_component(self, i):
sub_element_index, reference_component = self.extract_subelement_reference_component(i)
return self._sub_elements[sub_element_index].extract_reference_component(reference_component)

def is_cellwise_constant(self, component=None):
def is_cellwise_constant(self, component=None, domain=None):
"""Return whether the basis functions of this element is spatially constant over each cell."""
if component is None:
return all(e.is_cellwise_constant() for e in self.sub_elements)
else:
i, e = self.extract_component(component)
i, e = self.extract_component(domain, component)
return e.is_cellwise_constant()

def degree(self, component=None):
def degree(self, component=None, domain=None):
"""Return polynomial degree of finite element."""
if component is None:
return self._degree # from FiniteElementBase, computed as max of subelements in __init__
else:
i, e = self.extract_component(component)
i, e = self.extract_component(domain, component)
return e.degree()

@property
Expand Down Expand Up @@ -490,16 +490,16 @@ def flattened_sub_element_mapping(self):
"""Doc."""
return self._flattened_sub_element_mapping

def extract_subelement_component(self, i):
def extract_subelement_component(self, domain, i):
"""Extract direct subelement index and subelement relative.

component index for a given component index.
"""
if isinstance(i, int):
i = (i,)
self._check_component(i)
self._check_component(domain, i)

i = self.symmetry().get(i, i)
i = self.symmetry(domain).get(i, i)
l = len(self._shape) # noqa: E741
ii = i[:l]
jj = i[l:]
Expand All @@ -508,7 +508,7 @@ def extract_subelement_component(self, i):
k = self._sub_element_mapping[ii]
return (k, jj)

def symmetry(self):
def symmetry(self, domain=None):
r"""Return the symmetry dict, which is a mapping :math:`c_0 \\to c_1`.

meaning that component :math:`c_0` is represented by component
Expand Down
Loading