Skip to content

Commit

Permalink
Merge pull request #324 from firedrakeproject/pbrubeck/hhj
Browse files Browse the repository at this point in the history
Refactor variants
  • Loading branch information
pbrubeck authored Nov 19, 2024
2 parents e06308d + 8d69f5c commit 587166f
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 36 deletions.
78 changes: 42 additions & 36 deletions tsfc/finatinterface.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# along with FFC. If not, see <http://www.gnu.org/licenses/>.

import weakref
from functools import partial, singledispatch
from functools import singledispatch

import FIAT
import finat
Expand Down Expand Up @@ -51,6 +51,7 @@
"Johnson-Mercier": finat.JohnsonMercier,
"Nonconforming Arnold-Winther": finat.ArnoldWintherNC,
"Conforming Arnold-Winther": finat.ArnoldWinther,
"Hu-Zhang": finat.HuZhang,
"Hermite": finat.Hermite,
"Kong-Mulder-Veldhuizen": finat.KongMulderVeldhuizen,
"Argyris": finat.Argyris,
Expand All @@ -74,6 +75,8 @@
"Nedelec 2nd kind H(curl)": finat.NedelecSecondKind,
"Raviart-Thomas": finat.RaviartThomas,
"Regge": finat.Regge,
"Gopalakrishnan-Lederer-Schoberl 1st kind": finat.GopalakrishnanLedererSchoberlFirstKind,
"Gopalakrishnan-Lederer-Schoberl 2nd kind": finat.GopalakrishnanLedererSchoberlSecondKind,
"BDMCE": finat.BrezziDouglasMariniCubeEdge,
"BDMCF": finat.BrezziDouglasMariniCubeFace,
# These require special treatment below
Expand Down Expand Up @@ -124,6 +127,23 @@ def convert(element, **kwargs):
raise ValueError("Unsupported element type %s" % type(element))


cg_interval_variants = {
"fdm": finat.FDMLagrange,
"fdm_ipdg": finat.FDMLagrange,
"fdm_quadrature": finat.FDMQuadrature,
"fdm_broken": finat.FDMBrokenH1,
"fdm_hermite": finat.FDMHermite,
}


dg_interval_variants = {
"fdm": finat.FDMDiscontinuousLagrange,
"fdm_quadrature": finat.FDMDiscontinuousLagrange,
"fdm_ipdg": lambda *args: finat.DiscontinuousElement(finat.FDMLagrange(*args)),
"fdm_broken": finat.FDMBrokenL2,
}


# Base finite elements first
@convert.register(finat.ufl.FiniteElement)
def convert_finiteelement(element, **kwargs):
Expand Down Expand Up @@ -152,30 +172,19 @@ def convert_finiteelement(element, **kwargs):
finat_elem, deps = _create_element(element, **kwargs)
return finat.FlattenedDimensions(finat_elem), deps

kw = {}
kind = element.variant()
if kind is None:
kind = 'spectral' # default variant
is_interval = element.cell.cellname() == 'interval'

if element.family() in {"Raviart-Thomas", "Nedelec 1st kind H(curl)",
"Brezzi-Douglas-Marini", "Nedelec 2nd kind H(curl)",
"Argyris"}:
lmbda = partial(lmbda, variant=element.variant())
elif element.family() == "Lagrange":
if element.family() == "Lagrange":
if kind == 'spectral':
lmbda = finat.GaussLobattoLegendre
elif kind.startswith('integral'):
lmbda = partial(finat.IntegratedLegendre, variant=kind)
elif kind in ['fdm', 'fdm_ipdg'] and is_interval:
lmbda = finat.FDMLagrange
elif kind == 'fdm_quadrature' and is_interval:
lmbda = finat.FDMQuadrature
elif kind == 'fdm_broken' and is_interval:
lmbda = finat.FDMBrokenH1
elif kind == 'fdm_hermite' and is_interval:
lmbda = finat.FDMHermite
elif kind in ['demkowicz', 'fdm']:
lmbda = partial(finat.IntegratedLegendre, variant=kind)
elif element.cell.cellname() == "interval" and kind in cg_interval_variants:
lmbda = cg_interval_variants[kind]
elif kind.startswith('integral') or kind in ['demkowicz', 'fdm']:
lmbda = finat.IntegratedLegendre
kw["variant"] = kind
elif kind in ['mgd', 'feec', 'qb', 'mse']:
degree = element.degree()
shift_axes = kwargs["shift_axes"]
Expand All @@ -184,20 +193,17 @@ def convert_finiteelement(element, **kwargs):
return finat.RuntimeTabulated(cell, degree, variant=kind, shift_axes=shift_axes, restriction=restriction), deps
else:
# Let FIAT handle the general case
lmbda = partial(finat.Lagrange, variant=kind)
lmbda = finat.Lagrange
kw["variant"] = kind

elif element.family() in ["Discontinuous Lagrange", "Discontinuous Lagrange L2"]:
if kind == 'spectral':
lmbda = finat.GaussLegendre
elif kind.startswith('integral'):
lmbda = partial(finat.Legendre, variant=kind)
elif kind in ['fdm', 'fdm_quadrature'] and is_interval:
lmbda = finat.FDMDiscontinuousLagrange
elif kind == 'fdm_ipdg' and is_interval:
lmbda = lambda *args: finat.DiscontinuousElement(finat.FDMLagrange(*args))
elif kind in 'fdm_broken' and is_interval:
lmbda = finat.FDMBrokenL2
elif kind in ['demkowicz', 'fdm']:
lmbda = partial(finat.Legendre, variant=kind)
elif element.cell.cellname() == "interval" and kind in dg_interval_variants:
lmbda = dg_interval_variants[kind]
elif kind.startswith('integral') or kind in ['demkowicz', 'fdm']:
lmbda = finat.Legendre
kw["variant"] = kind
elif kind in ['mgd', 'feec', 'qb', 'mse']:
degree = element.degree()
shift_axes = kwargs["shift_axes"]
Expand All @@ -206,13 +212,13 @@ def convert_finiteelement(element, **kwargs):
return finat.RuntimeTabulated(cell, degree, variant=kind, shift_axes=shift_axes, restriction=restriction, continuous=False), deps
else:
# Let FIAT handle the general case
lmbda = partial(finat.DiscontinuousLagrange, variant=kind)
elif element.family() == ["DPC", "DPC L2", "S"]:
dim = element.cell.geometric_dimension()
if dim > 1:
element = element.reconstruct(cell=ufl.cell.hypercube(dim))
lmbda = finat.DiscontinuousLagrange
kw["variant"] = kind

elif element.variant() is not None:
kw["variant"] = element.variant()

return lmbda(cell, element.degree()), set()
return lmbda(cell, element.degree(), **kw), set()


# Element modifiers and compound element types
Expand Down
11 changes: 11 additions & 0 deletions tsfc/ufl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,10 @@ def apply_mapping(expression, element, domain):
G(X) = det(J)^2 K g(x) K^T i.e. G_il(X)=(detJ)^2 K_ij g_jk K_lk
'covariant contravariant piola' mapping for g:
G(X) = det(J) J^T g(x) K^T i.e. G_il(X) = det(J) J_ji g_jk(x) K_lk
If 'contravariant piola' or 'covariant piola' (or their double
variants) are applied to a matrix-valued function, the appropriate
mappings are applied row-by-row.
Expand Down Expand Up @@ -443,6 +447,13 @@ def apply_mapping(expression, element, domain):
*k, i, j, m, n = indices(len(expression.ufl_shape) + 2)
kmn = (*k, m, n)
rexpression = as_tensor(detJ**2 * K[i, m] * expression[kmn] * K[j, n], (*k, i, j))
elif mapping == "covariant contravariant piola":
J = Jacobian(mesh)
K = JacobianInverse(mesh)
detJ = JacobianDeterminant(mesh)
*k, i, j, m, n = indices(len(expression.ufl_shape) + 2)
kmn = (*k, m, n)
rexpression = as_tensor(detJ * J[m, i] * expression[kmn] * K[j, n], (*k, i, j))
elif mapping == "symmetries":
# This tells us how to get from the pieces of the reference
# space expression to the physical space one.
Expand Down

0 comments on commit 587166f

Please sign in to comment.