Skip to content

Commit

Permalink
Merge pull request #64 from firedrakeproject/rckirby/feature/macro
Browse files Browse the repository at this point in the history
Implement C0/C1 macroelements
  • Loading branch information
rckirby authored May 1, 2024
2 parents dbc1c5d + 23ad19a commit d0bea63
Show file tree
Hide file tree
Showing 25 changed files with 1,712 additions and 457 deletions.
28 changes: 11 additions & 17 deletions FIAT/P0.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,34 +17,28 @@
class P0Dual(dual_set.DualSet):
def __init__(self, ref_el):
entity_ids = {}
nodes = []
entity_permutations = {}
vs = numpy.array(ref_el.get_vertices())
if ref_el.get_dimension() == 0:
bary = ()
else:
bary = tuple(numpy.average(vs, 0))

nodes = [functional.PointEvaluation(ref_el, bary)]
entity_ids = {}
sd = ref_el.get_dimension()
top = ref_el.get_topology()
if sd == 0:
pts = [tuple() for entity in sorted(top[sd])]
else:
pts = [tuple(numpy.average(ref_el.get_vertices_of_subcomplex(top[sd][entity]), 0))
for entity in sorted(top[sd])]
nodes = [functional.PointEvaluation(ref_el, pt) for pt in pts]
for dim in sorted(top):
entity_ids[dim] = {}
entity_permutations[dim] = {}
sym_size = ref_el.symmetry_group_size(dim)
num_points = 1 if dim == sd else 0
if isinstance(dim, tuple):
assert isinstance(sym_size, tuple)
perms = {o: [] for o in numpy.ndindex(sym_size)}
perms = {o: list(range(num_points)) for o in numpy.ndindex(sym_size)}
else:
perms = {o: [] for o in range(sym_size)}
perms = {o: list(range(num_points)) for o in range(sym_size)}
for entity in sorted(top[dim]):
entity_ids[dim][entity] = []
entity_ids[dim][entity] = [entity] if dim == sd else []
entity_permutations[dim][entity] = perms
entity_ids[dim] = {0: [0]}
if isinstance(dim, tuple):
entity_permutations[dim][0] = {o: [0] for o in numpy.ndindex(sym_size)}
else:
entity_permutations[dim][0] = {o: [0] for o in range(sym_size)}

super(P0Dual, self).__init__(nodes, ref_el, entity_ids, entity_permutations)

Expand Down
6 changes: 5 additions & 1 deletion FIAT/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# Import finite element classes
from FIAT.finite_element import FiniteElement, CiarletElement # noqa: F401
from FIAT.argyris import Argyris
from FIAT.hct import HsiehCloughTocher
from FIAT.bernstein import Bernstein
from FIAT.bell import Bell
from FIAT.argyris import QuinticArgyris
Expand All @@ -30,6 +31,7 @@
from FIAT.morley import Morley
from FIAT.nedelec import Nedelec
from FIAT.nedelec_second_kind import NedelecSecondKind
from FIAT.hierarchical import Legendre, IntegratedLegendre
from FIAT.P0 import P0
from FIAT.raviart_thomas import RaviartThomas
from FIAT.crouzeix_raviart import CrouzeixRaviart
Expand All @@ -48,7 +50,6 @@
from FIAT.restricted import RestrictedElement # noqa: F401
from FIAT.quadrature_element import QuadratureElement # noqa: F401
from FIAT.kong_mulder_veldhuizen import KongMulderVeldhuizen # noqa: F401
from FIAT.hierarchical import Legendre, IntegratedLegendre # noqa: F401
from FIAT.fdm_element import FDMLagrange, FDMDiscontinuousLagrange, FDMQuadrature, FDMBrokenH1, FDMBrokenL2, FDMHermite # noqa: F401

# Important functionality
Expand All @@ -61,6 +62,7 @@

# List of supported elements and mapping to element classes
supported_elements = {"Argyris": Argyris,
"HsiehCloughTocher": HsiehCloughTocher,
"Bell": Bell,
"Bernstein": Bernstein,
"Brezzi-Douglas-Marini": BrezziDouglasMarini,
Expand All @@ -81,6 +83,8 @@
"Gauss-Lobatto-Legendre": GaussLobattoLegendre,
"Gauss-Legendre": GaussLegendre,
"Gauss-Radau": GaussRadau,
"Legendre": Legendre,
"Integrated Legendre": IntegratedLegendre,
"Morley": Morley,
"Nedelec 1st kind H(curl)": Nedelec,
"Nedelec 2nd kind H(curl)": NedelecSecondKind,
Expand Down
111 changes: 65 additions & 46 deletions FIAT/barycentric_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,40 @@
from FIAT.functional import index_iterator


def get_lagrange_points(nodes):
"""Extract singleton point for each node."""
points = []
for node in nodes:
pt, = node.get_point_dict()
points.append(pt)
return points


def barycentric_interpolation(nodes, wts, dmat, pts, order=0):
"""Evaluates a Lagrange basis on a line reference element
via the second barycentric interpolation formula. See Berrut and Trefethen (2004)
https://doi.org/10.1137/S0036144502417715 Eq. (4.2) & (9.4)
"""
if pts.dtype == object:
from sympy import simplify
sp_simplify = numpy.vectorize(simplify)
else:
sp_simplify = lambda x: x
phi = numpy.add.outer(-nodes, pts.flatten())
with numpy.errstate(divide='ignore', invalid='ignore'):
numpy.reciprocal(phi, out=phi)
numpy.multiply(phi, wts[:, None], out=phi)
numpy.multiply(1.0 / numpy.sum(phi, axis=0), phi, out=phi)
phi[phi != phi] = 1.0

phi = sp_simplify(phi)
results = {(0,): phi}
for r in range(1, order+1):
phi = sp_simplify(numpy.dot(dmat, phi))
results[(r,)] = phi
return results


def make_dmat(x):
"""Returns Lagrange differentiation matrix and barycentric weights
associated with x[j]."""
Expand All @@ -24,83 +58,68 @@ def make_dmat(x):


class LagrangeLineExpansionSet(expansions.LineExpansionSet):
"""Evaluates a Lagrange basis on a line reference element
via the second barycentric interpolation formula. See Berrut and Trefethen (2004)
https://doi.org/10.1137/S0036144502417715 Eq. (4.2) & (9.4)
"""
"""Lagrange polynomial expansion set for given points the line."""
def __init__(self, ref_el, pts):
self.points = pts
self.x = numpy.array(pts).flatten()
self.dmat, self.weights = make_dmat(self.x)
self.x = numpy.array(pts, dtype="d").flatten()
self.cell_node_map = expansions.compute_cell_point_map(ref_el, pts, unique=False)
self.dmats = []
self.weights = []
self.nodes = []
for ibfs in self.cell_node_map:
nodes = self.x[ibfs]
dmat, wts = make_dmat(nodes)
self.dmats.append(dmat)
self.weights.append(wts)
self.nodes.append(nodes)

self.degree = max(len(wts) for wts in self.weights)-1
self.recurrence_order = self.degree + 1
super(LagrangeLineExpansionSet, self).__init__(ref_el)

def get_num_members(self, n):
return len(self.points)

def get_cell_node_map(self, n):
return self.cell_node_map

def get_points(self):
return self.points

def get_dmats(self, degree):
return [self.dmat.T]

def tabulate(self, n, pts):
assert n == len(self.points)-1
results = numpy.add.outer(-self.x, numpy.array(pts).flatten())
with numpy.errstate(divide='ignore', invalid='ignore'):
numpy.reciprocal(results, out=results)
numpy.multiply(results, self.weights[:, None], out=results)
numpy.multiply(1.0 / numpy.sum(results, axis=0), results, out=results)

results[results != results] = 1.0
if results.dtype == object:
from sympy import simplify
results = numpy.vectorize(simplify)(results)
return results

def _tabulate(self, n, pts, order=0):
vals = self.tabulate(n, pts)
results = [vals]
for r in range(order):
vals = numpy.dot(self.dmat, vals)
if vals.dtype == object:
from sympy import simplify
vals = numpy.vectorize(simplify)(vals)
results.append(vals)
for r in range(order+1):
shape = results[r].shape
shape = shape[:1] + (1,)*r + shape[1:]
results[r] = numpy.reshape(results[r], shape)
return results
def get_dmats(self, degree, cell=0):
return [self.dmats[cell].T]

def _tabulate_on_cell(self, n, pts, order=0, cell=0, direction=None):
return barycentric_interpolation(self.nodes[cell], self.weights[cell], self.dmats[cell], pts, order=order)


class LagrangePolynomialSet(polynomial_set.PolynomialSet):

def __init__(self, ref_el, pts, shape=tuple()):
degree = len(pts) - 1
if ref_el.get_shape() != reference_element.LINE:
raise ValueError("Invalid reference element type.")

expansion_set = LagrangeLineExpansionSet(ref_el, pts)
degree = expansion_set.degree
if shape == tuple():
num_components = 1
else:
flat_shape = numpy.ravel(shape)
num_components = numpy.prod(flat_shape)
num_exp_functions = expansions.polynomial_dimension(ref_el, degree)
num_exp_functions = expansion_set.get_num_members(degree)
num_members = num_components * num_exp_functions
embedded_degree = degree
if ref_el.get_shape() == reference_element.LINE:
expansion_set = LagrangeLineExpansionSet(ref_el, pts)
else:
raise ValueError("Invalid reference element type.")

# set up coefficients
if shape == tuple():
coeffs = numpy.eye(num_members)
coeffs = numpy.eye(num_members, dtype="d")
else:
coeffs_shape = (num_members, *shape, num_exp_functions)
coeffs = numpy.zeros(coeffs_shape, "d")
# use functional's index_iterator function
cur_bf = 0
for idx in index_iterator(shape):
n = expansions.polynomial_dimension(ref_el, embedded_degree)
for exp_bf in range(n):
for exp_bf in range(num_exp_functions):
cur_idx = (cur_bf, *idx, exp_bf)
coeffs[cur_idx] = 1.0
cur_bf += 1
Expand Down
65 changes: 65 additions & 0 deletions FIAT/check_format_variant.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,22 @@
import re

from FIAT.macro import AlfeldSplit, IsoSplit

# dicts mapping Lagrange variant names to recursivenodes family names
supported_cg_variants = {
"spectral": "gll",
"chebyshev": "lgc",
"equispaced": "equispaced",
"gll": "gll"}

supported_dg_variants = {
"spectral": "gl",
"chebyshev": "gc",
"equispaced": "equispaced",
"equispaced_interior": "equispaced_interior",
"gll": "gll",
"gl": "gl"}


def check_format_variant(variant, degree):
if variant is None:
Expand All @@ -20,3 +37,51 @@ def check_format_variant(variant, degree):
'or variant="integral(q)"')

return variant, interpolant_degree


def parse_lagrange_variant(variant, discontinuous=False, integral=False):
"""Parses variant options for Lagrange elements.
variant may be a single option or comma-separated pair
indicating the dof type (integral, equispaced, spectral, etc)
and the type of splitting to give a macro-element (Alfeld, iso)
"""
if variant is None:
variant = "integral" if integral else "equispaced"
options = variant.replace(" ", "").split(",")
assert len(options) <= 2

default = "integral" if integral else "spectral"
if integral:
supported_point_variants = {"integral": None}
elif discontinuous:
supported_point_variants = supported_dg_variants
else:
supported_point_variants = supported_cg_variants

# defaults
splitting = None
splitting_args = tuple()
point_variant = supported_point_variants[default]

for pre_opt in options:
opt = pre_opt.lower()
if opt == "alfeld":
splitting = AlfeldSplit
elif opt == "iso":
splitting = IsoSplit
elif opt.startswith("iso"):
match = re.match(r"^iso(?:\((\d+)\))?$", opt)
k, = match.groups()
call_split = IsoSplit
splitting_args = (int(k),)
elif opt in supported_point_variants:
point_variant = supported_point_variants[opt]
else:
raise ValueError("Illegal variant option")

if discontinuous and splitting is not None and point_variant in supported_cg_variants.values():
raise ValueError("Illegal variant. DG macroelements with DOFs on subcell boundaries are not unisolvent.")
if len(splitting_args) > 0:
splitting = lambda T: call_split(T, *splitting_args, point_variant or "gll")
return splitting, point_variant
Loading

0 comments on commit d0bea63

Please sign in to comment.