forked from FEniCS/fiat
-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
104 additions
and
93 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,37 +8,44 @@ | |
# | ||
# Modified by Pablo D. Brubeck ([email protected]), 2021 | ||
|
||
from FIAT import finite_element, polynomial_set, dual_set, functional | ||
from FIAT.reference_element import POINT, LINE, TRIANGLE, TETRAHEDRON | ||
from FIAT import (finite_element, polynomial_set, dual_set, functional, | ||
quadrature, recursive_points) | ||
from FIAT.reference_element import POINT, LINE, TRIANGLE, TETRAHEDRON, UFCInterval | ||
from FIAT.orientation_utils import make_entity_permutations_simplex | ||
from FIAT.barycentric_interpolation import LagrangePolynomialSet | ||
from FIAT.recursive_points import make_node_family, recursive_points | ||
|
||
|
||
class GaussLegendrePointSet(recursive_points.RecursivePointSet): | ||
"""Recursive point set on simplices based on the Gauss-Legendre points on | ||
the interval""" | ||
def __init__(self): | ||
ref_el = UFCInterval() | ||
lr = quadrature.GaussLegendreQuadratureLineRule | ||
f = lambda n: lr(ref_el, n + 1).pts | ||
super(GaussLegendrePointSet, self).__init__(f) | ||
|
||
|
||
class GaussLegendreDualSet(dual_set.DualSet): | ||
"""The dual basis for 1D discontinuous elements with nodes at the | ||
"""The dual basis for discontinuous elements with nodes at the | ||
(recursive) Gauss-Legendre points.""" | ||
node_family = make_node_family("gl") | ||
point_set = GaussLegendrePointSet() | ||
|
||
def __init__(self, ref_el, degree): | ||
entity_ids = {} | ||
entity_permutations = {} | ||
|
||
# make nodes by getting points | ||
# need to do this dimension-by-dimension, facet-by-facet | ||
top = ref_el.get_topology() | ||
|
||
for dim in sorted(top): | ||
entity_ids[dim] = {} | ||
entity_permutations[dim] = {} | ||
perms = make_entity_permutations_simplex(dim, degree + 1 if dim == len(top) - 1 else -1) | ||
for entity in sorted(top[dim]): | ||
entity_ids[dim][entity] = [] | ||
entity_permutations[dim][entity] = perms | ||
entity_permutations[dim][entity] = [] | ||
|
||
pts = recursive_points(self.node_family, ref_el.vertices, degree) | ||
# make nodes by getting points | ||
pts = self.point_set.recursive_points(ref_el.get_vertices(), degree) | ||
nodes = [functional.PointEvaluation(ref_el, x) for x in pts] | ||
entity_ids[dim][0] = list(range(len(nodes))) | ||
entity_permutations[dim][0] = make_entity_permutations_simplex(dim, degree + 1) | ||
super(GaussLegendreDualSet, self).__init__(nodes, ref_el, entity_ids, entity_permutations) | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,17 +8,26 @@ | |
# | ||
# Modified by Pablo D. Brubeck ([email protected]), 2021 | ||
|
||
from FIAT import finite_element, polynomial_set, dual_set, functional | ||
from FIAT.reference_element import LINE, TRIANGLE, TETRAHEDRON | ||
from FIAT import (finite_element, polynomial_set, dual_set, functional, | ||
quadrature, recursive_points) | ||
from FIAT.reference_element import LINE, TRIANGLE, TETRAHEDRON, UFCInterval | ||
from FIAT.orientation_utils import make_entity_permutations_simplex | ||
from FIAT.barycentric_interpolation import LagrangePolynomialSet | ||
from FIAT.recursive_points import make_node_family, make_points | ||
|
||
|
||
class GaussLobattoLegendrePointSet(recursive_points.RecursivePointSet): | ||
|
||
def __init__(self): | ||
ref_el = UFCInterval() | ||
lr = quadrature.GaussLobattoLegendreQuadratureLineRule | ||
f = lambda n: lr(ref_el, n + 1).pts if n else None | ||
super(GaussLobattoLegendrePointSet, self).__init__(f) | ||
|
||
|
||
class GaussLobattoLegendreDualSet(dual_set.DualSet): | ||
"""The dual basis for simplex continuous elements with nodes at the | ||
"""The dual basis for continuous elements with nodes at the | ||
(recursive) Gauss-Lobatto points.""" | ||
node_family = make_node_family("gll") | ||
point_set = GaussLobattoLegendrePointSet() | ||
|
||
def __init__(self, ref_el, degree): | ||
entity_ids = {} | ||
|
@@ -35,7 +44,7 @@ def __init__(self, ref_el, degree): | |
entity_permutations[dim] = {} | ||
perms = {0: [0]} if dim == 0 else make_entity_permutations_simplex(dim, degree - dim) | ||
for entity in sorted(top[dim]): | ||
pts_cur = make_points(self.node_family, ref_el, dim, entity, degree) | ||
pts_cur = self.point_set.make_points(ref_el, dim, entity, degree) | ||
nodes_cur = [functional.PointEvaluation(ref_el, x) | ||
for x in pts_cur] | ||
nnodes_cur = len(nodes_cur) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,6 @@ | |
# | ||
# Written by Pablo D. Brubeck ([email protected]), 2023 | ||
|
||
from FIAT import quadrature, reference_element | ||
import numpy | ||
|
||
""" | ||
|
@@ -23,108 +22,106 @@ | |
""" | ||
|
||
|
||
def multiindex_equal(d, k, interior=0): | ||
"""A generator for :math:`d`-tuple multi-indices whose sum is :math:`k`. | ||
def multiindex_equal(d, isum, imin=0): | ||
"""A generator for d-tuple multi-indices whose sum is isum and minimum is imin. | ||
""" | ||
if d <= 0: | ||
return | ||
imin = interior | ||
imax = k - (d-1) * imin | ||
imax = isum - (d - 1) * imin | ||
if imax < imin: | ||
return | ||
for i in range(imin, imax): | ||
for a in multiindex_equal(d-1, k-i, interior=imin): | ||
for a in multiindex_equal(d - 1, isum - i, imin=imin): | ||
yield a + (i,) | ||
yield (imin,)*(d-1) + (imax,) | ||
yield (imin,) * (d - 1) + (imax,) | ||
|
||
|
||
class NodeFamily: | ||
"""Family of nodes on the unit interval. This class essentially is a | ||
class RecursivePointSet(object): | ||
"""Family of points on the unit interval. This class essentially is a | ||
lazy-evaluate-and-cache dictionary: the user passes a routine to evaluate | ||
entries for unknown keys """ | ||
|
||
def __init__(self, f): | ||
self._f = f | ||
self._cache = {} | ||
|
||
def __getitem__(self, key): | ||
def interval_points(self, degree): | ||
try: | ||
return self._cache[key] | ||
return self._cache[degree] | ||
except KeyError: | ||
x = self._f(key) | ||
x = self._f(degree) | ||
if x is None: | ||
x_ro = x | ||
else: | ||
x_ro = numpy.array(x).flatten() | ||
x_ro.setflags(write=False) | ||
return self._cache.setdefault(key, x_ro) | ||
|
||
return self._cache.setdefault(degree, x_ro) | ||
|
||
def _recursive(self, alpha): | ||
"""The barycentric (d-1)-simplex coordinates for a | ||
multiindex alpha of length d and sum n, based on a 1D node family.""" | ||
d = len(alpha) | ||
n = sum(alpha) | ||
b = numpy.zeros((d,), dtype="d") | ||
xn = self.interval_points(n) | ||
if xn is None: | ||
return b | ||
if d == 2: | ||
b[:] = xn[list(alpha)] | ||
return b | ||
weight = 0.0 | ||
for i in range(d): | ||
w = xn[n - alpha[i]] | ||
alpha_noti = alpha[:i] + alpha[i+1:] | ||
br = self._recursive(alpha_noti) | ||
b[:i] += w * br[:i] | ||
b[i+1:] += w * br[i:] | ||
weight += w | ||
b /= weight | ||
return b | ||
|
||
def make_node_family(family): | ||
line = reference_element.UFCInterval() | ||
def recursive_points(self, vertices, order, interior=0): | ||
X = numpy.array(vertices) | ||
get_point = lambda alpha: tuple(numpy.dot(self._recursive(alpha), X)) | ||
return list(map(get_point, multiindex_equal(len(vertices), order, interior))) | ||
|
||
def make_points(self, ref_el, dim, entity_id, order): | ||
"""Constructs a lattice of points on the entity_id:th | ||
facet of dimension dim. Order indicates how many points to | ||
include in each direction.""" | ||
if dim == 0: | ||
return (ref_el.get_vertices()[entity_id], ) | ||
elif 0 < dim < ref_el.get_spatial_dimension(): | ||
entity_verts = \ | ||
ref_el.get_vertices_of_subcomplex( | ||
ref_el.get_topology()[dim][entity_id]) | ||
return self.recursive_points(entity_verts, order, 1) | ||
elif dim == ref_el.get_spatial_dimension(): | ||
return self.recursive_points(ref_el.get_vertices(), order, 1) | ||
else: | ||
raise ValueError("illegal dimension") | ||
|
||
|
||
def make_recursive_point_set(family): | ||
from FIAT import quadrature, reference_element | ||
ref_el = reference_element.UFCInterval() | ||
if family == "equispaced": | ||
f = lambda n: numpy.linspace(0.0, 1.0, n + 1) | ||
elif family == "dg_equispaced": | ||
f = lambda n: numpy.linspace(1.0/(n+2.0), (n+1.0)/(n+2.0), n + 1) | ||
elif family == "gl": | ||
lr = quadrature.GaussLegendreQuadratureLineRule | ||
f = lambda n: lr(line, n + 1).pts | ||
f = lambda n: lr(ref_el, n + 1).pts | ||
elif family == "gll": | ||
lr = quadrature.GaussLobattoLegendreQuadratureLineRule | ||
f = lambda n: lr(line, n + 1).pts if n else None | ||
f = lambda n: lr(ref_el, n + 1).pts if n else None | ||
else: | ||
raise ValueError("Invalid node family %s" % family) | ||
return NodeFamily(f) | ||
|
||
|
||
def recursive(alpha, family): | ||
"""The barycentric d-simplex coordinates for a | ||
multiindex alpha with length n, based on a 1D node family.""" | ||
d = len(alpha) | ||
n = sum(alpha) | ||
b = numpy.zeros((d,), dtype="d") | ||
xn = family[n] | ||
if xn is None: | ||
return b | ||
if d == 2: | ||
b[:] = xn[list(alpha)] | ||
return b | ||
weight = 0.0 | ||
for i in range(d): | ||
w = xn[n - alpha[i]] | ||
alpha_noti = alpha[:i] + alpha[i+1:] | ||
br = recursive(alpha_noti, family) | ||
b[:i] += w * br[:i] | ||
b[i+1:] += w * br[i:] | ||
weight += w | ||
b /= weight | ||
return b | ||
|
||
|
||
def recursive_points(family, vertices, order, interior=0): | ||
X = numpy.array(vertices) | ||
get_point = lambda alpha: tuple(numpy.dot(recursive(alpha, family), X)) | ||
return list(map(get_point, multiindex_equal(len(vertices), order, interior=interior))) | ||
|
||
|
||
def make_points(family, ref_el, dim, entity_id, order): | ||
"""Constructs a lattice of points on the entity_id:th | ||
facet of dimension dim. Order indicates how many points to | ||
include in each direction.""" | ||
if dim == 0: | ||
return (ref_el.get_vertices()[entity_id], ) | ||
elif 0 < dim < ref_el.get_spatial_dimension(): | ||
entity_verts = \ | ||
ref_el.get_vertices_of_subcomplex( | ||
ref_el.get_topology()[dim][entity_id]) | ||
return recursive_points(family, entity_verts, order, interior=1) | ||
elif dim == ref_el.get_spatial_dimension(): | ||
return recursive_points(family, ref_el.get_vertices(), order, interior=1) | ||
else: | ||
raise ValueError("illegal dimension") | ||
return RecursivePointSet(f) | ||
|
||
|
||
if __name__ == "__main__": | ||
from FIAT import reference_element | ||
from matplotlib import pyplot as plt | ||
ref_el = reference_element.ufc_simplex(2) | ||
h = numpy.sqrt(3) | ||
|
@@ -140,16 +137,16 @@ def make_points(family, ref_el, dim, entity_id, order): | |
# rule = "equispaced" | ||
# dg_rule = "dg_equispaced" | ||
|
||
family = make_node_family(rule) | ||
dg_family = make_node_family(dg_rule) | ||
family = make_recursive_point_set(rule) | ||
dg_family = make_recursive_point_set(dg_rule) | ||
|
||
for d in range(1, 4): | ||
print(make_points(family, reference_element.ufc_simplex(d), d, 0, d)) | ||
print(family.make_points(reference_element.ufc_simplex(d), d, 0, d)) | ||
|
||
topology = ref_el.get_topology() | ||
for dim in topology: | ||
for entity in topology[dim]: | ||
pts = make_points(family, ref_el, dim, entity, order) | ||
pts = family.make_points(ref_el, dim, entity, order) | ||
if len(pts): | ||
x = numpy.array(pts) | ||
for r in range(1, 3): | ||
|
@@ -171,7 +168,7 @@ def make_points(family, ref_el, dim, entity_id, order): | |
x0 = sum(x[:d])/d | ||
plt.plot(x[:, 0], x[:, 1], "k") | ||
|
||
pts = recursive_points(dg_family, ref_el.vertices, order) | ||
pts = dg_family.recursive_points(ref_el.vertices, order) | ||
x = numpy.array(pts) | ||
for r in range(1, 3): | ||
th = r * (2*numpy.pi)/3 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters