Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Oct 17, 2023
1 parent 418161f commit c9f3021
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 93 deletions.
31 changes: 19 additions & 12 deletions FIAT/gauss_legendre.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,37 +8,44 @@
#
# Modified by Pablo D. Brubeck ([email protected]), 2021

from FIAT import finite_element, polynomial_set, dual_set, functional
from FIAT.reference_element import POINT, LINE, TRIANGLE, TETRAHEDRON
from FIAT import (finite_element, polynomial_set, dual_set, functional,
quadrature, recursive_points)
from FIAT.reference_element import POINT, LINE, TRIANGLE, TETRAHEDRON, UFCInterval
from FIAT.orientation_utils import make_entity_permutations_simplex
from FIAT.barycentric_interpolation import LagrangePolynomialSet
from FIAT.recursive_points import make_node_family, recursive_points


class GaussLegendrePointSet(recursive_points.RecursivePointSet):
"""Recursive point set on simplices based on the Gauss-Legendre points on
the interval"""
def __init__(self):
ref_el = UFCInterval()
lr = quadrature.GaussLegendreQuadratureLineRule
f = lambda n: lr(ref_el, n + 1).pts
super(GaussLegendrePointSet, self).__init__(f)


class GaussLegendreDualSet(dual_set.DualSet):
"""The dual basis for 1D discontinuous elements with nodes at the
"""The dual basis for discontinuous elements with nodes at the
(recursive) Gauss-Legendre points."""
node_family = make_node_family("gl")
point_set = GaussLegendrePointSet()

def __init__(self, ref_el, degree):
entity_ids = {}
entity_permutations = {}

# make nodes by getting points
# need to do this dimension-by-dimension, facet-by-facet
top = ref_el.get_topology()

for dim in sorted(top):
entity_ids[dim] = {}
entity_permutations[dim] = {}
perms = make_entity_permutations_simplex(dim, degree + 1 if dim == len(top) - 1 else -1)
for entity in sorted(top[dim]):
entity_ids[dim][entity] = []
entity_permutations[dim][entity] = perms
entity_permutations[dim][entity] = []

pts = recursive_points(self.node_family, ref_el.vertices, degree)
# make nodes by getting points
pts = self.point_set.recursive_points(ref_el.get_vertices(), degree)
nodes = [functional.PointEvaluation(ref_el, x) for x in pts]
entity_ids[dim][0] = list(range(len(nodes)))
entity_permutations[dim][0] = make_entity_permutations_simplex(dim, degree + 1)
super(GaussLegendreDualSet, self).__init__(nodes, ref_el, entity_ids, entity_permutations)


Expand Down
21 changes: 15 additions & 6 deletions FIAT/gauss_lobatto_legendre.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,26 @@
#
# Modified by Pablo D. Brubeck ([email protected]), 2021

from FIAT import finite_element, polynomial_set, dual_set, functional
from FIAT.reference_element import LINE, TRIANGLE, TETRAHEDRON
from FIAT import (finite_element, polynomial_set, dual_set, functional,
quadrature, recursive_points)
from FIAT.reference_element import LINE, TRIANGLE, TETRAHEDRON, UFCInterval
from FIAT.orientation_utils import make_entity_permutations_simplex
from FIAT.barycentric_interpolation import LagrangePolynomialSet
from FIAT.recursive_points import make_node_family, make_points


class GaussLobattoLegendrePointSet(recursive_points.RecursivePointSet):

def __init__(self):
ref_el = UFCInterval()
lr = quadrature.GaussLobattoLegendreQuadratureLineRule
f = lambda n: lr(ref_el, n + 1).pts if n else None
super(GaussLobattoLegendrePointSet, self).__init__(f)


class GaussLobattoLegendreDualSet(dual_set.DualSet):
"""The dual basis for simplex continuous elements with nodes at the
"""The dual basis for continuous elements with nodes at the
(recursive) Gauss-Lobatto points."""
node_family = make_node_family("gll")
point_set = GaussLobattoLegendrePointSet()

def __init__(self, ref_el, degree):
entity_ids = {}
Expand All @@ -35,7 +44,7 @@ def __init__(self, ref_el, degree):
entity_permutations[dim] = {}
perms = {0: [0]} if dim == 0 else make_entity_permutations_simplex(dim, degree - dim)
for entity in sorted(top[dim]):
pts_cur = make_points(self.node_family, ref_el, dim, entity, degree)
pts_cur = self.point_set.make_points(ref_el, dim, entity, degree)
nodes_cur = [functional.PointEvaluation(ref_el, x)
for x in pts_cur]
nnodes_cur = len(nodes_cur)
Expand Down
139 changes: 68 additions & 71 deletions FIAT/recursive_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
#
# Written by Pablo D. Brubeck ([email protected]), 2023

from FIAT import quadrature, reference_element
import numpy

"""
Expand All @@ -23,108 +22,106 @@
"""


def multiindex_equal(d, k, interior=0):
"""A generator for :math:`d`-tuple multi-indices whose sum is :math:`k`.
def multiindex_equal(d, isum, imin=0):
"""A generator for d-tuple multi-indices whose sum is isum and minimum is imin.
"""
if d <= 0:
return
imin = interior
imax = k - (d-1) * imin
imax = isum - (d - 1) * imin
if imax < imin:
return
for i in range(imin, imax):
for a in multiindex_equal(d-1, k-i, interior=imin):
for a in multiindex_equal(d - 1, isum - i, imin=imin):
yield a + (i,)
yield (imin,)*(d-1) + (imax,)
yield (imin,) * (d - 1) + (imax,)


class NodeFamily:
"""Family of nodes on the unit interval. This class essentially is a
class RecursivePointSet(object):
"""Family of points on the unit interval. This class essentially is a
lazy-evaluate-and-cache dictionary: the user passes a routine to evaluate
entries for unknown keys """

def __init__(self, f):
self._f = f
self._cache = {}

def __getitem__(self, key):
def interval_points(self, degree):
try:
return self._cache[key]
return self._cache[degree]
except KeyError:
x = self._f(key)
x = self._f(degree)
if x is None:
x_ro = x
else:
x_ro = numpy.array(x).flatten()
x_ro.setflags(write=False)
return self._cache.setdefault(key, x_ro)

return self._cache.setdefault(degree, x_ro)

def _recursive(self, alpha):
"""The barycentric (d-1)-simplex coordinates for a
multiindex alpha of length d and sum n, based on a 1D node family."""
d = len(alpha)
n = sum(alpha)
b = numpy.zeros((d,), dtype="d")
xn = self.interval_points(n)
if xn is None:
return b
if d == 2:
b[:] = xn[list(alpha)]
return b
weight = 0.0
for i in range(d):
w = xn[n - alpha[i]]
alpha_noti = alpha[:i] + alpha[i+1:]
br = self._recursive(alpha_noti)
b[:i] += w * br[:i]
b[i+1:] += w * br[i:]
weight += w
b /= weight
return b

def make_node_family(family):
line = reference_element.UFCInterval()
def recursive_points(self, vertices, order, interior=0):
X = numpy.array(vertices)
get_point = lambda alpha: tuple(numpy.dot(self._recursive(alpha), X))
return list(map(get_point, multiindex_equal(len(vertices), order, interior)))

def make_points(self, ref_el, dim, entity_id, order):
"""Constructs a lattice of points on the entity_id:th
facet of dimension dim. Order indicates how many points to
include in each direction."""
if dim == 0:
return (ref_el.get_vertices()[entity_id], )
elif 0 < dim < ref_el.get_spatial_dimension():
entity_verts = \
ref_el.get_vertices_of_subcomplex(
ref_el.get_topology()[dim][entity_id])
return self.recursive_points(entity_verts, order, 1)
elif dim == ref_el.get_spatial_dimension():
return self.recursive_points(ref_el.get_vertices(), order, 1)
else:
raise ValueError("illegal dimension")


def make_recursive_point_set(family):
from FIAT import quadrature, reference_element
ref_el = reference_element.UFCInterval()
if family == "equispaced":
f = lambda n: numpy.linspace(0.0, 1.0, n + 1)
elif family == "dg_equispaced":
f = lambda n: numpy.linspace(1.0/(n+2.0), (n+1.0)/(n+2.0), n + 1)
elif family == "gl":
lr = quadrature.GaussLegendreQuadratureLineRule
f = lambda n: lr(line, n + 1).pts
f = lambda n: lr(ref_el, n + 1).pts
elif family == "gll":
lr = quadrature.GaussLobattoLegendreQuadratureLineRule
f = lambda n: lr(line, n + 1).pts if n else None
f = lambda n: lr(ref_el, n + 1).pts if n else None
else:
raise ValueError("Invalid node family %s" % family)
return NodeFamily(f)


def recursive(alpha, family):
"""The barycentric d-simplex coordinates for a
multiindex alpha with length n, based on a 1D node family."""
d = len(alpha)
n = sum(alpha)
b = numpy.zeros((d,), dtype="d")
xn = family[n]
if xn is None:
return b
if d == 2:
b[:] = xn[list(alpha)]
return b
weight = 0.0
for i in range(d):
w = xn[n - alpha[i]]
alpha_noti = alpha[:i] + alpha[i+1:]
br = recursive(alpha_noti, family)
b[:i] += w * br[:i]
b[i+1:] += w * br[i:]
weight += w
b /= weight
return b


def recursive_points(family, vertices, order, interior=0):
X = numpy.array(vertices)
get_point = lambda alpha: tuple(numpy.dot(recursive(alpha, family), X))
return list(map(get_point, multiindex_equal(len(vertices), order, interior=interior)))


def make_points(family, ref_el, dim, entity_id, order):
"""Constructs a lattice of points on the entity_id:th
facet of dimension dim. Order indicates how many points to
include in each direction."""
if dim == 0:
return (ref_el.get_vertices()[entity_id], )
elif 0 < dim < ref_el.get_spatial_dimension():
entity_verts = \
ref_el.get_vertices_of_subcomplex(
ref_el.get_topology()[dim][entity_id])
return recursive_points(family, entity_verts, order, interior=1)
elif dim == ref_el.get_spatial_dimension():
return recursive_points(family, ref_el.get_vertices(), order, interior=1)
else:
raise ValueError("illegal dimension")
return RecursivePointSet(f)


if __name__ == "__main__":
from FIAT import reference_element
from matplotlib import pyplot as plt
ref_el = reference_element.ufc_simplex(2)
h = numpy.sqrt(3)
Expand All @@ -140,16 +137,16 @@ def make_points(family, ref_el, dim, entity_id, order):
# rule = "equispaced"
# dg_rule = "dg_equispaced"

family = make_node_family(rule)
dg_family = make_node_family(dg_rule)
family = make_recursive_point_set(rule)
dg_family = make_recursive_point_set(dg_rule)

for d in range(1, 4):
print(make_points(family, reference_element.ufc_simplex(d), d, 0, d))
print(family.make_points(reference_element.ufc_simplex(d), d, 0, d))

topology = ref_el.get_topology()
for dim in topology:
for entity in topology[dim]:
pts = make_points(family, ref_el, dim, entity, order)
pts = family.make_points(ref_el, dim, entity, order)
if len(pts):
x = numpy.array(pts)
for r in range(1, 3):
Expand All @@ -171,7 +168,7 @@ def make_points(family, ref_el, dim, entity_id, order):
x0 = sum(x[:d])/d
plt.plot(x[:, 0], x[:, 1], "k")

pts = recursive_points(dg_family, ref_el.vertices, order)
pts = dg_family.recursive_points(ref_el.vertices, order)
x = numpy.array(pts)
for r in range(1, 3):
th = r * (2*numpy.pi)/3
Expand Down
3 changes: 1 addition & 2 deletions test/unit/test_gauss_legendre.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ def test_gl_basis_values(dim, degree):
v = lambda x: sum(x)**test_degree
coefs = [n(v) for n in fe.dual.nodes]
integral = np.dot(coefs, np.dot(tab, q.wts))
reference = np.dot([sum(x)**test_degree
for x in q.pts], q.wts)
reference = np.dot([v(x) for x in q.pts], q.wts)
assert np.allclose(integral, reference, rtol=1e-14)


Expand Down
3 changes: 1 addition & 2 deletions test/unit/test_gauss_lobatto_legendre.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ def test_gll_basis_values(dim, degree):
v = lambda x: sum(x)**test_degree
coefs = [n(v) for n in fe.dual.nodes]
integral = np.dot(coefs, np.dot(tab, q.wts))
reference = np.dot([sum(x)**test_degree
for x in q.pts], q.wts)
reference = np.dot([v(x) for x in q.pts], q.wts)
assert np.allclose(integral, reference, rtol=1e-14)


Expand Down

0 comments on commit c9f3021

Please sign in to comment.