Skip to content

Commit

Permalink
get interior muliindices
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Oct 14, 2023
1 parent c08dd7c commit 91844d4
Showing 1 changed file with 20 additions and 20 deletions.
40 changes: 20 additions & 20 deletions FIAT/recursive_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
}
"""

def multiindex_equal(d, k):
def multiindex_equal(d, k, interior=0):
"""A generator for :math:`d`-tuple multi-indices whose sum is :math:`k`.
Args:
Expand All @@ -46,10 +46,10 @@ def multiindex_equal(d, k):
return
if k < 0:
return
for i in range(k):
for a in multiindex_equal(d-1, k-i):
for i in range(interior, k-interior):
for a in multiindex_equal(d-1, k-i, interior=interior):
yield (i,) + a
yield (k,) + (0,)*(d-1)
yield (k - (d-1)*interior,) + (interior,)*(d-1)


class NodeFamily:
Expand Down Expand Up @@ -80,7 +80,7 @@ def recursive(alpha, family):
if xn is None:
return b
if d == 2:
b[:] = xn[alpha]
b[:] = xn[list(alpha)]
return b
weight = 0.0
for i in range(d):
Expand All @@ -94,6 +94,7 @@ def recursive(alpha, family):
b /= weight
return b


def recursive_points(ref_el, order, rule="gll", interior=0):
if rule == "gll":
lr = quadrature.GaussLobattoLegendreQuadratureLineRule
Expand All @@ -103,15 +104,14 @@ def recursive_points(ref_el, order, rule="gll", interior=0):
raise ValueError("Unsupported quadrature rule %s" % rule)

line = reference_element.UFCInterval()
f = lambda n: numpy.array(lr(line, n+1).pts).flatten() if n>=1 else None
f = lambda n: numpy.array(lr(line, n + 1).pts).flatten() if n else None
family = NodeFamily(f)

verts = ref_el.vertices
tdim = len(verts) - 1
vs = numpy.array(verts)
affine_map = lambda x: numpy.dot(x, vs)
vertices = ref_el.vertices
X = numpy.array(vertices)
affine_map = lambda b: numpy.dot(b, X)
get_point = lambda alpha: tuple(affine_map(recursive(alpha, family)))
return list(map(get_point, multiindex_equal(tdim+1, order)))
return list(map(get_point, multiindex_equal(len(vertices), order, interior=interior)))


if __name__ == "__main__":
Expand All @@ -120,16 +120,16 @@ def recursive_points(ref_el, order, rule="gll", interior=0):
h = 0.5 * numpy.sqrt(3)
ref_el.vertices = [(0, h), (-1.0, -h), (1.0, -h)]

order = 5
order = 7
rule = "gll"
pts = recursive_points(ref_el, order, rule=rule)
x = []
y = []
for p in pts:
x.append(p[0])
y.append(p[1])

plt.scatter(x, y)
for interior in range(2):
pts = recursive_points(ref_el, order, rule=rule, interior=interior)
x = []
y = []
for p in pts:
x.append(p[0])
y.append(p[1])
plt.scatter(x, y)

plt.gca().set_aspect('equal', 'box')
plt.show()

0 comments on commit 91844d4

Please sign in to comment.