Skip to content

Commit

Permalink
Merge branch 'pbrubeck/restrict-dual' into pbrubeck/demkowicz
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Aug 12, 2024
2 parents 780b3c7 + e1203ec commit ee8a8b9
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 70 deletions.
52 changes: 52 additions & 0 deletions FIAT/dual_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,58 @@ def to_riesz(self, poly_set):
mat[ells] += numpy.dot(dwts[alpha], dexpansion_values[alpha].T)
return mat

def get_indices(self, restriction_domain, take_closure=True):
"""Returns the list of dofs with support on a given restriction domain.
:arg restriction_domain: can be 'interior', 'vertex', 'edge', 'face' or 'facet'
:kwarg take_closure: Are we taking the closure of the restriction domain?
"""
entity_dofs = self.get_entity_ids()
if restriction_domain == "interior":
# Return dofs from interior, never taking the closure
indices = []
entities = entity_dofs[max(entity_dofs.keys())]
for (entity, ids) in sorted_by_key(entities):
indices.extend(ids)
return indices

# otherwise return dofs with d <= dim
if restriction_domain == "vertex":
dim = 0
elif restriction_domain == "edge":
dim = 1
elif restriction_domain == "face":
dim = 2
elif restriction_domain == "facet":
dim = self.get_reference_element().get_spatial_dimension() - 1
else:
raise RuntimeError("Invalid restriction domain")

is_prodcell = isinstance(max(entity_dofs.keys()), tuple)

ldim = 0 if take_closure else dim
indices = []
for d in range(ldim, dim + 1):
if is_prodcell:
for edim in entity_dofs:
if sum(edim) == d:
entities = entity_dofs[edim]
for (entity, ids) in sorted_by_key(entities):
indices.extend(ids)
else:
entities = entity_dofs[d]
for (entity, ids) in sorted_by_key(entities):
indices.extend(ids)
return indices


def sorted_by_key(mapping):
"Sort dict items by key, allowing different key types."
# Python3 doesn't allow comparing builtins of different type, therefore the typename trick here
def _key(x):
return (type(x[0]).__name__, x[0])
return sorted(mapping.items(), key=_key)


def make_entity_closure_ids(ref_el, entity_ids):
entity_closure_ids = {}
Expand Down
106 changes: 36 additions & 70 deletions FIAT/restricted.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,41 @@
from FIAT.finite_element import CiarletElement


class RestrictedDualSet(DualSet):
"""Restrict the given DualSet to the specified list of dofs."""

def __init__(self, dual, indices):
ref_el = dual.get_reference_element()
nodes_old = dual.get_nodes()
dof_counter = 0
entity_ids = {}
nodes = []
for d, entities in dual.get_entity_ids().items():
entity_ids[d] = {}
for entity, dofs in entities.items():
entity_ids[d][entity] = []
for dof in dofs:
if dof not in indices:
continue
entity_ids[d][entity].append(dof_counter)
dof_counter += 1
nodes.append(nodes_old[dof])
assert dof_counter == len(indices)
self._dual = dual
super(RestrictedDualSet, self).__init__(nodes, ref_el, entity_ids)

def get_indices(self, restriction_domain, take_closure=True):
"""Return the list of dofs with support on a given restriction domain.
:arg restriction_domain: can be 'interior', 'vertex', 'edge', 'face' or 'facet'
:kwarg take_closure: Are we taking the closure of the restriction domain?
"""
# Call get_indices on the parent class to support multiple restriction domains
return type(self._dual).get_indices(self, restriction_domain, take_closure=take_closure)


class RestrictedElement(CiarletElement):
"""Restrict given element to specified list of dofs."""
"""Restrict the given element to the specified list of dofs."""

def __init__(self, element, indices=None, restriction_domain=None, take_closure=True):
'''For sake of argument, indices overrides restriction_domain'''
Expand All @@ -18,7 +51,7 @@ def __init__(self, element, indices=None, restriction_domain=None, take_closure=
raise RuntimeError("Either indices or restriction_domain must be passed in")

if not indices:
indices = _get_indices(element, restriction_domain, take_closure)
indices = element.dual.get_indices(restriction_domain, take_closure=take_closure)

if isinstance(indices, str):
raise RuntimeError("variable 'indices' was a string; did you forget to use a keyword?")
Expand All @@ -29,29 +62,11 @@ def __init__(self, element, indices=None, restriction_domain=None, take_closure=
self._element = element
self._indices = indices

# Fetch reference element
ref_el = element.get_reference_element()

# Restrict primal set
poly_set = element.get_nodal_basis().take(indices)

# Restrict dual set
dof_counter = 0
entity_ids = {}
nodes = []
nodes_old = element.dual_basis()
for d, entities in element.entity_dofs().items():
entity_ids[d] = {}
for entity, dofs in entities.items():
entity_ids[d][entity] = []
for dof in dofs:
if dof not in indices:
continue
entity_ids[d][entity].append(dof_counter)
dof_counter += 1
nodes.append(nodes_old[dof])
assert dof_counter == len(indices)
dual = DualSet(nodes, ref_el, entity_ids)
dual = RestrictedDualSet(element.get_dual_set(), indices)

# Restrict mapping
mapping_old = element.mapping()
Expand All @@ -60,52 +75,3 @@ def __init__(self, element, indices=None, restriction_domain=None, take_closure=

# Call constructor of CiarletElement
super(RestrictedElement, self).__init__(poly_set, dual, 0, element.get_formdegree(), mapping_new[0])


def sorted_by_key(mapping):
"Sort dict items by key, allowing different key types."
# Python3 doesn't allow comparing builtins of different type, therefore the typename trick here
def _key(x):
return (type(x[0]).__name__, x[0])
return sorted(mapping.items(), key=_key)


def _get_indices(element, restriction_domain, take_closure):
"Restriction domain can be 'interior', 'vertex', 'edge', 'face' or 'facet'"

if restriction_domain == "interior":
# Return dofs from interior
return element.entity_dofs()[max(element.entity_dofs().keys())][0]

# otherwise return dofs with d <= dim
if restriction_domain == "vertex":
dim = 0
elif restriction_domain == "edge":
dim = 1
elif restriction_domain == "face":
dim = 2
elif restriction_domain == "facet":
dim = element.get_reference_element().get_spatial_dimension() - 1
else:
raise RuntimeError("Invalid restriction domain")

is_prodcell = isinstance(max(element.entity_dofs().keys()), tuple)

ldim = 0 if take_closure else dim
entity_dofs = element.entity_dofs()
indices = []
for d in range(ldim, dim + 1):
if is_prodcell:
for a in range(d + 1):
b = d - a
try:
entities = entity_dofs[(a, b)]
for (entity, index) in sorted_by_key(entities):
indices += index
except KeyError:
pass
else:
entities = entity_dofs[d]
for (entity, index) in sorted_by_key(entities):
indices += index
return indices

0 comments on commit ee8a8b9

Please sign in to comment.