From 91844d4f3c70ba738a312d55d45e06f275e0e155 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Sat, 14 Oct 2023 07:20:31 +0100 Subject: [PATCH] get interior muliindices --- FIAT/recursive_points.py | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/FIAT/recursive_points.py b/FIAT/recursive_points.py index 62577926a..f55a17fa7 100644 --- a/FIAT/recursive_points.py +++ b/FIAT/recursive_points.py @@ -22,7 +22,7 @@ } """ -def multiindex_equal(d, k): +def multiindex_equal(d, k, interior=0): """A generator for :math:`d`-tuple multi-indices whose sum is :math:`k`. Args: @@ -46,10 +46,10 @@ def multiindex_equal(d, k): return if k < 0: return - for i in range(k): - for a in multiindex_equal(d-1, k-i): + for i in range(interior, k-interior): + for a in multiindex_equal(d-1, k-i, interior=interior): yield (i,) + a - yield (k,) + (0,)*(d-1) + yield (k - (d-1)*interior,) + (interior,)*(d-1) class NodeFamily: @@ -80,7 +80,7 @@ def recursive(alpha, family): if xn is None: return b if d == 2: - b[:] = xn[alpha] + b[:] = xn[list(alpha)] return b weight = 0.0 for i in range(d): @@ -94,6 +94,7 @@ def recursive(alpha, family): b /= weight return b + def recursive_points(ref_el, order, rule="gll", interior=0): if rule == "gll": lr = quadrature.GaussLobattoLegendreQuadratureLineRule @@ -103,15 +104,14 @@ def recursive_points(ref_el, order, rule="gll", interior=0): raise ValueError("Unsupported quadrature rule %s" % rule) line = reference_element.UFCInterval() - f = lambda n: numpy.array(lr(line, n+1).pts).flatten() if n>=1 else None + f = lambda n: numpy.array(lr(line, n + 1).pts).flatten() if n else None family = NodeFamily(f) - verts = ref_el.vertices - tdim = len(verts) - 1 - vs = numpy.array(verts) - affine_map = lambda x: numpy.dot(x, vs) + vertices = ref_el.vertices + X = numpy.array(vertices) + affine_map = lambda b: numpy.dot(b, X) get_point = lambda alpha: tuple(affine_map(recursive(alpha, family))) - return list(map(get_point, multiindex_equal(tdim+1, order))) + return list(map(get_point, multiindex_equal(len(vertices), order, interior=interior))) if __name__ == "__main__": @@ -120,16 +120,16 @@ def recursive_points(ref_el, order, rule="gll", interior=0): h = 0.5 * numpy.sqrt(3) ref_el.vertices = [(0, h), (-1.0, -h), (1.0, -h)] - order = 5 + order = 7 rule = "gll" - pts = recursive_points(ref_el, order, rule=rule) - x = [] - y = [] - for p in pts: - x.append(p[0]) - y.append(p[1]) - - plt.scatter(x, y) + for interior in range(2): + pts = recursive_points(ref_el, order, rule=rule, interior=interior) + x = [] + y = [] + for p in pts: + x.append(p[0]) + y.append(p[1]) + plt.scatter(x, y) plt.gca().set_aspect('equal', 'box') plt.show()