From fa8d312f3c0ed9ede4c0706ac2802b17278a5895 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 11 Dec 2024 14:30:16 +0000 Subject: [PATCH] HDivTrace variants --- FIAT/hdiv_trace.py | 24 ++++++++++++------------ test/FIAT/unit/test_fiat.py | 4 ++-- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/FIAT/hdiv_trace.py b/FIAT/hdiv_trace.py index a68509f79..264205a3e 100644 --- a/FIAT/hdiv_trace.py +++ b/FIAT/hdiv_trace.py @@ -5,6 +5,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import numpy as np +from FIAT.barycentric_interpolation import get_lagrange_points from FIAT.discontinuous_lagrange import DiscontinuousLagrange from FIAT.dual_set import DualSet from FIAT.finite_element import FiniteElement @@ -39,7 +40,7 @@ class HDivTrace(FiniteElement): arises in several DG formulations. """ - def __init__(self, ref_el, degree): + def __init__(self, ref_el, degree, variant=None): """Constructor for the HDivTrace element. :arg ref_el: A reference element, which may be a tensor product @@ -47,6 +48,7 @@ def __init__(self, ref_el, degree): :arg degree: The degree of approximation. If on a tensor product cell, then provide a tuple of degrees if you want varying degrees. + :arg variant: The point distribution variant passed on to recursivenodes. """ sd = ref_el.get_spatial_dimension() if sd in (0, 1): @@ -83,7 +85,7 @@ def __init__(self, ref_el, degree): # We have a facet entity! if cell.get_spatial_dimension() == facet_sd: - dg_elements[top_dim] = construct_dg_element(cell, degree) + dg_elements[top_dim] = construct_dg_element(cell, degree, variant) # Initialize for entity in entities: entity_dofs[top_dim][entity] = [] @@ -97,15 +99,14 @@ def __init__(self, ref_el, degree): nf = element.space_dimension() num_facets = len(topology[facet_dim]) + facet_pts = get_lagrange_points(element.dual_basis()) for i in range(num_facets): entity_dofs[facet_dim][i] = list(range(offset, offset + nf)) offset += nf # Run over nodes and collect the points for point evaluations - for dof in element.dual_basis(): - facet_pt, = dof.get_point_dict() - transform = ref_el.get_entity_transform(facet_dim, i) - pts.append(tuple(transform(facet_pt))) + transform = ref_el.get_entity_transform(facet_dim, i) + pts.extend(transform(facet_pts)) # Setting up dual basis - only point evaluations nodes = [PointEvaluation(ref_el, pt) for pt in pts] @@ -265,19 +266,18 @@ def is_nodal(): return True -def construct_dg_element(ref_el, degree): +def construct_dg_element(ref_el, degree, variant): """Constructs a discontinuous galerkin element of a given degree on a particular reference cell. """ if ref_el.get_shape() in [LINE, TRIANGLE]: - dg_element = DiscontinuousLagrange(ref_el, degree) + dg_element = DiscontinuousLagrange(ref_el, degree, variant) # Quadrilateral facets could be on a FiredrakeQuadrilateral. # In this case, we treat this as an interval x interval cell: elif ref_el.get_shape() == QUADRILATERAL: - dg_a = DiscontinuousLagrange(ufc_simplex(1), degree) - dg_b = DiscontinuousLagrange(ufc_simplex(1), degree) - dg_element = TensorProductElement(dg_a, dg_b) + dg_line = DiscontinuousLagrange(ufc_simplex(1), degree, variant) + dg_element = TensorProductElement(dg_line, dg_line) # This handles the more general case for facets: elif ref_el.get_shape() == TENSORPRODUCT: @@ -285,7 +285,7 @@ def construct_dg_element(ref_el, degree): "Must provide the same number of degrees as the number " "of cells that make up the tensor product cell." ) - sub_elements = [construct_dg_element(c, d) + sub_elements = [construct_dg_element(c, d, variant) for c, d in zip(ref_el.cells, degree) if c.get_shape() != POINT] diff --git a/test/FIAT/unit/test_fiat.py b/test/FIAT/unit/test_fiat.py index 34d5e333d..5a1dfb21b 100644 --- a/test/FIAT/unit/test_fiat.py +++ b/test/FIAT/unit/test_fiat.py @@ -552,11 +552,11 @@ def test_nodality_tabulate(element): "HDivTrace(T, 0)", "HDivTrace(T, 1)", "HDivTrace(T, 2)", - "HDivTrace(T, 3)", + "HDivTrace(T, 3, 'spectral')", "HDivTrace(S, 0)", "HDivTrace(S, 1)", "HDivTrace(S, 2)", - "HDivTrace(S, 3)", + "HDivTrace(S, 3, 'spectral')", ]) def test_facet_nodality_tabulate(element): """Check that certain elements (which do no implement get_nodal_basis)