Skip to content

Commit

Permalink
HDivTrace variants
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Dec 11, 2024
1 parent 11dc133 commit fa8d312
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
24 changes: 12 additions & 12 deletions FIAT/hdiv_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

import numpy as np
from FIAT.barycentric_interpolation import get_lagrange_points
from FIAT.discontinuous_lagrange import DiscontinuousLagrange
from FIAT.dual_set import DualSet
from FIAT.finite_element import FiniteElement
Expand Down Expand Up @@ -39,14 +40,15 @@ class HDivTrace(FiniteElement):
arises in several DG formulations.
"""

def __init__(self, ref_el, degree):
def __init__(self, ref_el, degree, variant=None):
"""Constructor for the HDivTrace element.
:arg ref_el: A reference element, which may be a tensor product
cell.
:arg degree: The degree of approximation. If on a tensor product
cell, then provide a tuple of degrees if you want
varying degrees.
:arg variant: The point distribution variant passed on to recursivenodes.
"""
sd = ref_el.get_spatial_dimension()
if sd in (0, 1):
Expand Down Expand Up @@ -83,7 +85,7 @@ def __init__(self, ref_el, degree):

# We have a facet entity!
if cell.get_spatial_dimension() == facet_sd:
dg_elements[top_dim] = construct_dg_element(cell, degree)
dg_elements[top_dim] = construct_dg_element(cell, degree, variant)
# Initialize
for entity in entities:
entity_dofs[top_dim][entity] = []
Expand All @@ -97,15 +99,14 @@ def __init__(self, ref_el, degree):
nf = element.space_dimension()
num_facets = len(topology[facet_dim])

facet_pts = get_lagrange_points(element.dual_basis())
for i in range(num_facets):
entity_dofs[facet_dim][i] = list(range(offset, offset + nf))
offset += nf

# Run over nodes and collect the points for point evaluations
for dof in element.dual_basis():
facet_pt, = dof.get_point_dict()
transform = ref_el.get_entity_transform(facet_dim, i)
pts.append(tuple(transform(facet_pt)))
transform = ref_el.get_entity_transform(facet_dim, i)
pts.extend(transform(facet_pts))

# Setting up dual basis - only point evaluations
nodes = [PointEvaluation(ref_el, pt) for pt in pts]
Expand Down Expand Up @@ -265,27 +266,26 @@ def is_nodal():
return True


def construct_dg_element(ref_el, degree):
def construct_dg_element(ref_el, degree, variant):
"""Constructs a discontinuous galerkin element of a given degree
on a particular reference cell.
"""
if ref_el.get_shape() in [LINE, TRIANGLE]:
dg_element = DiscontinuousLagrange(ref_el, degree)
dg_element = DiscontinuousLagrange(ref_el, degree, variant)

# Quadrilateral facets could be on a FiredrakeQuadrilateral.
# In this case, we treat this as an interval x interval cell:
elif ref_el.get_shape() == QUADRILATERAL:
dg_a = DiscontinuousLagrange(ufc_simplex(1), degree)
dg_b = DiscontinuousLagrange(ufc_simplex(1), degree)
dg_element = TensorProductElement(dg_a, dg_b)
dg_line = DiscontinuousLagrange(ufc_simplex(1), degree, variant)
dg_element = TensorProductElement(dg_line, dg_line)

# This handles the more general case for facets:
elif ref_el.get_shape() == TENSORPRODUCT:
assert len(degree) == len(ref_el.cells), (
"Must provide the same number of degrees as the number "
"of cells that make up the tensor product cell."
)
sub_elements = [construct_dg_element(c, d)
sub_elements = [construct_dg_element(c, d, variant)
for c, d in zip(ref_el.cells, degree)
if c.get_shape() != POINT]

Expand Down
4 changes: 2 additions & 2 deletions test/FIAT/unit/test_fiat.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,11 +552,11 @@ def test_nodality_tabulate(element):
"HDivTrace(T, 0)",
"HDivTrace(T, 1)",
"HDivTrace(T, 2)",
"HDivTrace(T, 3)",
"HDivTrace(T, 3, 'spectral')",
"HDivTrace(S, 0)",
"HDivTrace(S, 1)",
"HDivTrace(S, 2)",
"HDivTrace(S, 3)",
"HDivTrace(S, 3, 'spectral')",
])
def test_facet_nodality_tabulate(element):
"""Check that certain elements (which do no implement get_nodal_basis)
Expand Down

0 comments on commit fa8d312

Please sign in to comment.