From f27c8b5d5fd9b79f18d7c610e889aa643b9ef9fe Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Sat, 14 Dec 2024 17:57:29 +0000 Subject: [PATCH 1/3] Single source of truth --- tsfc/finatinterface.py | 73 +++--------------------------------------- 1 file changed, 5 insertions(+), 68 deletions(-) diff --git a/tsfc/finatinterface.py b/tsfc/finatinterface.py index b7e3d0ad72..b4812fbd6e 100644 --- a/tsfc/finatinterface.py +++ b/tsfc/finatinterface.py @@ -31,79 +31,16 @@ "create_element", "supported_elements") -supported_elements = { - # These all map directly to FInAT elements - "Bernstein": finat.Bernstein, - "Bernardi-Raugel": finat.BernardiRaugel, - "Bernardi-Raugel Bubble": finat.BernardiRaugelBubble, - "Brezzi-Douglas-Marini": finat.BrezziDouglasMarini, - "Brezzi-Douglas-Fortin-Marini": finat.BrezziDouglasFortinMarini, - "Bubble": finat.Bubble, - "FacetBubble": finat.FacetBubble, - "Crouzeix-Raviart": finat.CrouzeixRaviart, - "Discontinuous Lagrange": finat.DiscontinuousLagrange, - "Discontinuous Raviart-Thomas": lambda c, d: finat.DiscontinuousElement(finat.RaviartThomas(c, d)), - "Discontinuous Taylor": finat.DiscontinuousTaylor, - "Gauss-Legendre": finat.GaussLegendre, - "Gauss-Lobatto-Legendre": finat.GaussLobattoLegendre, - "HDiv Trace": finat.HDivTrace, - "Hellan-Herrmann-Johnson": finat.HellanHerrmannJohnson, - "Johnson-Mercier": finat.JohnsonMercier, - "Nonconforming Arnold-Winther": finat.ArnoldWintherNC, - "Conforming Arnold-Winther": finat.ArnoldWinther, - "Hu-Zhang": finat.HuZhang, - "Hermite": finat.Hermite, - "Kong-Mulder-Veldhuizen": finat.KongMulderVeldhuizen, - "Argyris": finat.Argyris, - "Hsieh-Clough-Tocher": finat.HsiehCloughTocher, - "QuadraticPowellSabin6": finat.QuadraticPowellSabin6, - "QuadraticPowellSabin12": finat.QuadraticPowellSabin12, - "Reduced-Hsieh-Clough-Tocher": finat.ReducedHsiehCloughTocher, - "Mardal-Tai-Winther": finat.MardalTaiWinther, - "Alfeld-Sorokina": finat.AlfeldSorokina, - "Arnold-Qin": finat.ArnoldQin, - "Reduced-Arnold-Qin": finat.ReducedArnoldQin, - "Christiansen-Hu": finat.ChristiansenHu, - "Guzman-Neilan 1st kind H1": finat.GuzmanNeilanFirstKindH1, - "Guzman-Neilan 2nd kind H1": finat.GuzmanNeilanSecondKindH1, - "Guzman-Neilan Bubble": finat.GuzmanNeilanBubble, - "Guzman-Neilan H1(div)": finat.GuzmanNeilanH1div, - "Morley": finat.Morley, - "Bell": finat.Bell, - "Lagrange": finat.Lagrange, - "Nedelec 1st kind H(curl)": finat.Nedelec, - "Nedelec 2nd kind H(curl)": finat.NedelecSecondKind, - "Raviart-Thomas": finat.RaviartThomas, - "Regge": finat.Regge, - "Gopalakrishnan-Lederer-Schoberl 1st kind": finat.GopalakrishnanLedererSchoberlFirstKind, - "Gopalakrishnan-Lederer-Schoberl 2nd kind": finat.GopalakrishnanLedererSchoberlSecondKind, - "BDMCE": finat.BrezziDouglasMariniCubeEdge, - "BDMCF": finat.BrezziDouglasMariniCubeFace, - # These require special treatment below - "DQ": None, - "Q": None, - "RTCE": None, - "RTCF": None, - "NCE": None, - "NCF": None, - "Real": finat.Real, - "DPC": finat.DPC, - "S": finat.Serendipity, - "SminusF": finat.TrimmedSerendipityFace, - "SminusDiv": finat.TrimmedSerendipityDiv, - "SminusE": finat.TrimmedSerendipityEdge, - "SminusCurl": finat.TrimmedSerendipityCurl, - "DPC L2": finat.DPC, - "Discontinuous Lagrange L2": finat.DiscontinuousLagrange, - "Gauss-Legendre L2": finat.GaussLegendre, - "DQ L2": None, - "Direct Serendipity": finat.DirectSerendipity, -} +supported_elements = finat.supported_elements.copy() """A :class:`.dict` mapping UFL element family names to their FInAT-equivalent constructors. If the value is ``None``, the UFL element is supported, but must be handled specially because it doesn't have a direct FInAT equivalent.""" +# These require special treatment below +for tp_family in ("Q", "DQ", "DQ L2", "RTCE", "RTCF", "NCE", "NCF"): + supported_elements[tp_family] = None + def as_fiat_cell(cell): """Convert a ufl cell to a FIAT cell. From 044d8ec0da4b4b92fd2f5bc78180fab7a8748c87 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 19 Dec 2024 07:59:01 -0600 Subject: [PATCH 2/3] finat.element_factory --- docs/source/element_list.py | 2 +- firedrake/assemble.py | 2 +- firedrake/cython/dmcommon.pyx | 2 +- firedrake/cython/extrusion_numbering.pyx | 2 +- firedrake/extrusion_utils.py | 2 +- firedrake/functionspacedata.py | 2 +- firedrake/interpolation.py | 2 +- firedrake/mesh.py | 2 +- firedrake/mg/kernels.py | 4 +- firedrake/output/paraview_reordering.py | 2 +- firedrake/pointquery_utils.py | 2 +- firedrake/preconditioners/fdm.py | 2 +- firedrake/preconditioners/pmg.py | 2 +- firedrake/slate/slac/kernel_builder.py | 2 +- .../test_interpolate_p3intmoments.py | 4 +- tests/tsfc/test_create_fiat_element.py | 2 +- tests/tsfc/test_create_finat_element.py | 2 +- tests/tsfc/test_dual_evaluation.py | 2 +- .../tsfc/test_interpolation_factorisation.py | 2 +- tests/tsfc/test_tsfc_274.py | 2 +- tsfc/fem.py | 8 +- tsfc/finatinterface.py | 302 ------------------ tsfc/kernel_interface/common.py | 4 +- tsfc/kernel_interface/firedrake_loopy.py | 2 +- 24 files changed, 29 insertions(+), 331 deletions(-) delete mode 100644 tsfc/finatinterface.py diff --git a/docs/source/element_list.py b/docs/source/element_list.py index 2c2d048f01..6e8c6ab981 100644 --- a/docs/source/element_list.py +++ b/docs/source/element_list.py @@ -1,6 +1,6 @@ from finat.ufl.elementlist import ufl_elements # ~ from ufl.finiteelement.elementlist import ufl_elements -from tsfc.finatinterface import supported_elements +from finat.element_factory import supported_elements import csv shape_names = { diff --git a/firedrake/assemble.py b/firedrake/assemble.py index f451b3f596..f3049ae01c 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -12,7 +12,7 @@ import numpy from pyadjoint.tape import annotate_tape from tsfc import kernel_args -from tsfc.finatinterface import create_element +from finat.element_factory import create_element from tsfc.ufl_utils import extract_firedrake_constants import ufl import finat.ufl diff --git a/firedrake/cython/dmcommon.pyx b/firedrake/cython/dmcommon.pyx index eda8470d4e..75f980d7ba 100644 --- a/firedrake/cython/dmcommon.pyx +++ b/firedrake/cython/dmcommon.pyx @@ -11,7 +11,7 @@ from mpi4py import MPI from firedrake.utils import IntType, ScalarType from libc.string cimport memset from libc.stdlib cimport qsort -from tsfc.finatinterface import as_fiat_cell +from finat.element_factory import as_fiat_cell cimport numpy as np cimport mpi4py.MPI as MPI diff --git a/firedrake/cython/extrusion_numbering.pyx b/firedrake/cython/extrusion_numbering.pyx index 278af86c75..86167ed20b 100644 --- a/firedrake/cython/extrusion_numbering.pyx +++ b/firedrake/cython/extrusion_numbering.pyx @@ -193,7 +193,7 @@ from mpi4py.libmpi cimport (MPI_Op_create, MPI_OP_NULL, MPI_Op_free, MPI_User_function) from pyop2 import op2 from firedrake.utils import IntType -from tsfc.finatinterface import as_fiat_cell +from finat.element_factory import as_fiat_cell cimport numpy cimport mpi4py.MPI as MPI diff --git a/firedrake/extrusion_utils.py b/firedrake/extrusion_utils.py index b038d904af..4298a21d7b 100644 --- a/firedrake/extrusion_utils.py +++ b/firedrake/extrusion_utils.py @@ -8,7 +8,7 @@ from pyop2.caching import serial_cache from firedrake.petsc import PETSc from firedrake.utils import IntType, RealType, ScalarType -from tsfc.finatinterface import create_element +from finat.element_factory import create_element import loopy as lp from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 # noqa: F401 from firedrake.parameters import target diff --git a/firedrake/functionspacedata.py b/firedrake/functionspacedata.py index be00e3ffff..a1b6190cfb 100644 --- a/firedrake/functionspacedata.py +++ b/firedrake/functionspacedata.py @@ -20,7 +20,7 @@ from decorator import decorator from functools import partial -from tsfc.finatinterface import create_element as _create_element +from finat.element_factory import create_element as _create_element from pyop2 import op2 from firedrake.utils import IntType diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index e8555ba7dd..8a20e3da73 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -14,7 +14,7 @@ from pyop2 import op2 from pyop2.caching import memory_and_disk_cache -from tsfc.finatinterface import create_element, as_fiat_cell +from finat.element_factory import create_element, as_fiat_cell from tsfc import compile_expression_dual_evaluation from tsfc.ufl_utils import extract_firedrake_constants diff --git a/firedrake/mesh.py b/firedrake/mesh.py index bef4d38bf3..4e209340c7 100644 --- a/firedrake/mesh.py +++ b/firedrake/mesh.py @@ -45,7 +45,7 @@ ngsPETSc = None # Only for docstring import mpi4py # noqa: F401 -from tsfc.finatinterface import as_fiat_cell +from finat.element_factory import as_fiat_cell __all__ = [ diff --git a/firedrake/mg/kernels.py b/firedrake/mg/kernels.py index 5405b6d726..f892f6260c 100644 --- a/firedrake/mg/kernels.py +++ b/firedrake/mg/kernels.py @@ -30,7 +30,7 @@ from tsfc.driver import TSFCIntegralDataInfo from tsfc.kernel_interface.common import lower_integral_type from tsfc.parameters import default_parameters -from tsfc.finatinterface import create_element +from finat.element_factory import create_element from finat.quadrature import make_quadrature from firedrake.pointquery_utils import dX_norm_square, X_isub_dX, init_X, inside_check, is_affine, celldist_l1_c_expr from firedrake.pointquery_utils import to_reference_coords_newton_step as to_reference_coords_newton_step_body @@ -45,7 +45,7 @@ def to_reference_coordinates(ufl_coordinate_element, parameters=None): parameters = _ # Create FInAT element - element = tsfc.finatinterface.create_element(ufl_coordinate_element) + element = finat.element_factory.create_element(ufl_coordinate_element) gdim, = ufl_coordinate_element.reference_value_shape cell = ufl_coordinate_element.cell diff --git a/firedrake/output/paraview_reordering.py b/firedrake/output/paraview_reordering.py index 20fbe50099..8b6edb147d 100644 --- a/firedrake/output/paraview_reordering.py +++ b/firedrake/output/paraview_reordering.py @@ -1,4 +1,4 @@ -from tsfc.finatinterface import create_base_element +from finat.element_factory import create_base_element import numpy as np from pyop2.utils import as_tuple diff --git a/firedrake/pointquery_utils.py b/firedrake/pointquery_utils.py index 4793af0bfe..2308616622 100644 --- a/firedrake/pointquery_utils.py +++ b/firedrake/pointquery_utils.py @@ -221,7 +221,7 @@ def compile_coordinate_element(mesh: MeshGeometry, contains_eps: float, paramete ufl_coordinate_element = mesh.ufl_coordinate_element() # Create FInAT element - element = tsfc.finatinterface.create_element(ufl_coordinate_element) + element = finat.element_factory.create_element(ufl_coordinate_element) code = { "geometric_dimension": mesh.geometric_dimension(), diff --git a/firedrake/preconditioners/fdm.py b/firedrake/preconditioners/fdm.py index 1696801cb4..e75172dc8e 100644 --- a/firedrake/preconditioners/fdm.py +++ b/firedrake/preconditioners/fdm.py @@ -19,7 +19,7 @@ from firedrake_citations import Citations from ufl.algorithms.ad import expand_derivatives from ufl.algorithms.expand_indices import expand_indices -from tsfc.finatinterface import create_element +from finat.element_factory import create_element from pyop2.compilation import load from pyop2.mpi import COMM_SELF from pyop2.sparsity import get_preallocation diff --git a/firedrake/preconditioners/pmg.py b/firedrake/preconditioners/pmg.py index 251f585cbe..8726db6c9b 100644 --- a/firedrake/preconditioners/pmg.py +++ b/firedrake/preconditioners/pmg.py @@ -9,7 +9,7 @@ from firedrake.solving_utils import _SNESContext from firedrake.tsfc_interface import extract_numbered_coefficients from firedrake.utils import ScalarType_c, IntType_c, cached_property -from tsfc.finatinterface import create_element +from finat.element_factory import create_element from tsfc import compile_expression_dual_evaluation from pyop2 import op2 from pyop2.caching import serial_cache diff --git a/firedrake/slate/slac/kernel_builder.py b/firedrake/slate/slac/kernel_builder.py index cbc9b6fed2..419931232f 100644 --- a/firedrake/slate/slac/kernel_builder.py +++ b/firedrake/slate/slac/kernel_builder.py @@ -14,7 +14,7 @@ from firedrake.slate.slac.tsfc_driver import compile_terminal_form from tsfc import kernel_args -from tsfc.finatinterface import create_element +from finat.element_factory import create_element from tsfc.loopy import create_domains, assign_dtypes from pytools import UniqueNameGenerator diff --git a/tests/firedrake/regression/test_interpolate_p3intmoments.py b/tests/firedrake/regression/test_interpolate_p3intmoments.py index a56cf13ad9..6ef0ef34d2 100644 --- a/tests/firedrake/regression/test_interpolate_p3intmoments.py +++ b/tests/firedrake/regression/test_interpolate_p3intmoments.py @@ -9,7 +9,7 @@ from FIAT.quadrature import make_quadrature from FIAT.polynomial_set import ONPolynomialSet from finat.fiat_elements import ScalarFiatElement -from tsfc.finatinterface import convert, as_fiat_cell +from finat.element_factory import convert, as_fiat_cell import finat.ufl ufcint = UFCInterval() @@ -89,7 +89,7 @@ def __init__(self, cell, degree): super().__init__(P3IntMoments(cell, degree)) -# Replace the old tsfc.finatinterface.convert dispatch with a new one that +# Replace the old finat.element_factory.convert dispatch with a new one that # gives the the new FInAT element for P3 on an interval with variant # "interior-moment" old_convert = convert.dispatch(finat.ufl.FiniteElement) diff --git a/tests/tsfc/test_create_fiat_element.py b/tests/tsfc/test_create_fiat_element.py index f8a7d6efc4..f99053ab97 100644 --- a/tests/tsfc/test_create_fiat_element.py +++ b/tests/tsfc/test_create_fiat_element.py @@ -5,7 +5,7 @@ import ufl import finat.ufl -from tsfc.finatinterface import create_element as _create_element +from finat.element_factory import create_element as _create_element supported_elements = { diff --git a/tests/tsfc/test_create_finat_element.py b/tests/tsfc/test_create_finat_element.py index c0f34c292c..7964824c2c 100644 --- a/tests/tsfc/test_create_finat_element.py +++ b/tests/tsfc/test_create_finat_element.py @@ -3,7 +3,7 @@ import ufl import finat.ufl import finat -from tsfc.finatinterface import create_element, supported_elements +from finat.element_factory import create_element, supported_elements @pytest.fixture(params=["BDM", diff --git a/tests/tsfc/test_dual_evaluation.py b/tests/tsfc/test_dual_evaluation.py index b4f6e9770a..85f1617678 100644 --- a/tests/tsfc/test_dual_evaluation.py +++ b/tests/tsfc/test_dual_evaluation.py @@ -1,7 +1,7 @@ import pytest import ufl import finat.ufl -from tsfc.finatinterface import create_element +from finat.element_factory import create_element from tsfc import compile_expression_dual_evaluation diff --git a/tests/tsfc/test_interpolation_factorisation.py b/tests/tsfc/test_interpolation_factorisation.py index b3d4e3288b..4355c24b1f 100644 --- a/tests/tsfc/test_interpolation_factorisation.py +++ b/tests/tsfc/test_interpolation_factorisation.py @@ -7,7 +7,7 @@ from finat.ufl import FiniteElement, VectorElement, TensorElement from tsfc import compile_expression_dual_evaluation -from tsfc.finatinterface import create_element +from finat.element_factory import create_element @pytest.fixture(params=[interval, quadrilateral, hexahedron], diff --git a/tests/tsfc/test_tsfc_274.py b/tests/tsfc/test_tsfc_274.py index 453d8746e8..39da190496 100644 --- a/tests/tsfc/test_tsfc_274.py +++ b/tests/tsfc/test_tsfc_274.py @@ -2,7 +2,7 @@ import numpy from finat.point_set import PointSet from gem.interpreter import evaluate -from tsfc.finatinterface import create_element +from finat.element_factory import create_element from ufl import quadrilateral from finat.ufl import FiniteElement, RestrictedElement diff --git a/tsfc/fem.py b/tsfc/fem.py index 99251ed0c6..a5fa2009b1 100644 --- a/tsfc/fem.py +++ b/tsfc/fem.py @@ -15,6 +15,7 @@ PhysicalGeometry) from finat.point_set import PointSet, PointSingleton from finat.quadrature import make_quadrature +from finat.element_factory import as_fiat_cell, create_element from gem.node import traversal from gem.optimise import constant_fold_zero, ffc_rounding from gem.unconcatenate import unconcatenate @@ -32,7 +33,6 @@ from ufl.domain import extract_unique_domain from tsfc import ufl2gem -from tsfc.finatinterface import as_fiat_cell, create_element from tsfc.kernel_interface import ProxyKernelInterface from tsfc.modified_terminals import (analyse_modified_terminal, construct_modified_terminal) @@ -270,9 +270,10 @@ def get_quadrature_rule(fiat_cell, integration_dim, quadrature_degree, scheme): def make_basis_evaluation_key(ctx, finat_element, mt, entity_id): + ufl_element = mt.terminal.ufl_element() domain = extract_unique_domain(mt.terminal) coordinate_element = domain.ufl_coordinate_element() - return (finat_element, mt.local_derivatives, ctx.point_set, ctx.integration_dim, entity_id, coordinate_element, mt.restriction) + return (ufl_element, mt.local_derivatives, ctx.point_set, ctx.integration_dim, entity_id, coordinate_element, mt.restriction) class PointSetContext(ContextBase): @@ -697,8 +698,7 @@ def take_singleton(xs): for alpha, tables in per_derivative.items()} # Coefficient evaluation - ctx.index_cache.setdefault(terminal.ufl_element(), element.get_indices()) - beta = ctx.index_cache[terminal.ufl_element()] + beta = ctx.index_cache.setdefault(terminal.ufl_element(), element.get_indices()) zeta = element.get_value_indices() vec_beta, = gem.optimise.remove_componenttensors([gem.Indexed(vec, beta)]) value_dict = {} diff --git a/tsfc/finatinterface.py b/tsfc/finatinterface.py deleted file mode 100644 index b4812fbd6e..0000000000 --- a/tsfc/finatinterface.py +++ /dev/null @@ -1,302 +0,0 @@ -# This file was modified from FFC -# (http://bitbucket.org/fenics-project/ffc), copyright notice -# reproduced below. -# -# Copyright (C) 2009-2013 Kristian B. Oelgaard and Anders Logg -# -# This file is part of FFC. -# -# FFC is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# FFC is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with FFC. If not, see . - -import weakref -from functools import singledispatch - -import FIAT -import finat -import finat.ufl -import ufl - -__all__ = ("as_fiat_cell", "create_base_element", - "create_element", "supported_elements") - - -supported_elements = finat.supported_elements.copy() -"""A :class:`.dict` mapping UFL element family names to their -FInAT-equivalent constructors. If the value is ``None``, the UFL -element is supported, but must be handled specially because it doesn't -have a direct FInAT equivalent.""" - -# These require special treatment below -for tp_family in ("Q", "DQ", "DQ L2", "RTCE", "RTCF", "NCE", "NCF"): - supported_elements[tp_family] = None - - -def as_fiat_cell(cell): - """Convert a ufl cell to a FIAT cell. - - :arg cell: the :class:`ufl.Cell` to convert.""" - if not isinstance(cell, ufl.AbstractCell): - raise ValueError("Expecting a UFL Cell") - return FIAT.ufc_cell(cell) - - -@singledispatch -def convert(element, **kwargs): - """Handler for converting UFL elements to FInAT elements. - - :arg element: The UFL element to convert. - - Do not use this function directly, instead call - :func:`create_element`.""" - if element.family() in supported_elements: - raise ValueError("Element %s supported, but no handler provided" % element) - raise ValueError("Unsupported element type %s" % type(element)) - - -cg_interval_variants = { - "fdm": finat.FDMLagrange, - "fdm_ipdg": finat.FDMLagrange, - "fdm_quadrature": finat.FDMQuadrature, - "fdm_broken": finat.FDMBrokenH1, - "fdm_hermite": finat.FDMHermite, -} - - -dg_interval_variants = { - "fdm": finat.FDMDiscontinuousLagrange, - "fdm_quadrature": finat.FDMDiscontinuousLagrange, - "fdm_ipdg": lambda *args: finat.DiscontinuousElement(finat.FDMLagrange(*args)), - "fdm_broken": finat.FDMBrokenL2, -} - - -# Base finite elements first -@convert.register(finat.ufl.FiniteElement) -def convert_finiteelement(element, **kwargs): - cell = as_fiat_cell(element.cell) - if element.family() == "Quadrature": - degree = element.degree() - scheme = element.quadrature_scheme() - if degree is None or scheme is None: - raise ValueError("Quadrature scheme and degree must be specified!") - - return finat.make_quadrature_element(cell, degree, scheme), set() - lmbda = supported_elements[element.family()] - if element.family() == "Real" and element.cell.cellname() in {"quadrilateral", "hexahedron"}: - lmbda = None - element = finat.ufl.FiniteElement("DQ", element.cell, 0) - if lmbda is None: - if element.cell.cellname() == "quadrilateral": - # Handle quadrilateral short names like RTCF and RTCE. - element = element.reconstruct(cell=quadrilateral_tpc) - elif element.cell.cellname() == "hexahedron": - # Handle hexahedron short names like NCF and NCE. - element = element.reconstruct(cell=hexahedron_tpc) - else: - raise ValueError("%s is supported, but handled incorrectly" % - element.family()) - finat_elem, deps = _create_element(element, **kwargs) - return finat.FlattenedDimensions(finat_elem), deps - - finat_kwargs = {} - kind = element.variant() - if kind is None: - kind = 'spectral' # default variant - - if element.family() == "Lagrange": - if kind == 'spectral': - lmbda = finat.GaussLobattoLegendre - elif element.cell.cellname() == "interval" and kind in cg_interval_variants: - lmbda = cg_interval_variants[kind] - elif any(map(kind.startswith, ['integral', 'demkowicz', 'fdm'])): - lmbda = finat.IntegratedLegendre - finat_kwargs["variant"] = kind - elif kind in ['mgd', 'feec', 'qb', 'mse']: - degree = element.degree() - shift_axes = kwargs["shift_axes"] - restriction = kwargs["restriction"] - deps = {"shift_axes", "restriction"} - return finat.RuntimeTabulated(cell, degree, variant=kind, shift_axes=shift_axes, restriction=restriction), deps - else: - # Let FIAT handle the general case - lmbda = finat.Lagrange - finat_kwargs["variant"] = kind - - elif element.family() in ["Discontinuous Lagrange", "Discontinuous Lagrange L2"]: - if kind == 'spectral': - lmbda = finat.GaussLegendre - elif element.cell.cellname() == "interval" and kind in dg_interval_variants: - lmbda = dg_interval_variants[kind] - elif any(map(kind.startswith, ['integral', 'demkowicz', 'fdm'])): - lmbda = finat.Legendre - finat_kwargs["variant"] = kind - elif kind in ['mgd', 'feec', 'qb', 'mse']: - degree = element.degree() - shift_axes = kwargs["shift_axes"] - restriction = kwargs["restriction"] - deps = {"shift_axes", "restriction"} - return finat.RuntimeTabulated(cell, degree, variant=kind, shift_axes=shift_axes, restriction=restriction, continuous=False), deps - else: - # Let FIAT handle the general case - lmbda = finat.DiscontinuousLagrange - finat_kwargs["variant"] = kind - - elif element.variant() is not None: - finat_kwargs["variant"] = element.variant() - - return lmbda(cell, element.degree(), **finat_kwargs), set() - - -# Element modifiers and compound element types -@convert.register(finat.ufl.BrokenElement) -def convert_brokenelement(element, **kwargs): - finat_elem, deps = _create_element(element._element, **kwargs) - return finat.DiscontinuousElement(finat_elem), deps - - -@convert.register(finat.ufl.EnrichedElement) -def convert_enrichedelement(element, **kwargs): - elements, deps = zip(*[_create_element(elem, **kwargs) - for elem in element._elements]) - return finat.EnrichedElement(elements), set.union(*deps) - - -@convert.register(finat.ufl.NodalEnrichedElement) -def convert_nodalenrichedelement(element, **kwargs): - elements, deps = zip(*[_create_element(elem, **kwargs) - for elem in element._elements]) - return finat.NodalEnrichedElement(elements), set.union(*deps) - - -@convert.register(finat.ufl.MixedElement) -def convert_mixedelement(element, **kwargs): - elements, deps = zip(*[_create_element(elem, **kwargs) - for elem in element.sub_elements]) - return finat.MixedElement(elements), set.union(*deps) - - -@convert.register(finat.ufl.VectorElement) -@convert.register(finat.ufl.TensorElement) -def convert_tensorelement(element, **kwargs): - inner_elem, deps = _create_element(element.sub_elements[0], **kwargs) - shape = element.reference_value_shape - shape = shape[:len(shape) - len(inner_elem.value_shape)] - shape_innermost = kwargs["shape_innermost"] - return (finat.TensorFiniteElement(inner_elem, shape, not shape_innermost), - deps | {"shape_innermost"}) - - -@convert.register(finat.ufl.TensorProductElement) -def convert_tensorproductelement(element, **kwargs): - cell = element.cell - if type(cell) is not ufl.TensorProductCell: - raise ValueError("TensorProductElement not on TensorProductCell?") - shift_axes = kwargs["shift_axes"] - dim_offset = 0 - elements = [] - deps = set() - for elem in element.sub_elements: - kwargs["shift_axes"] = shift_axes + dim_offset - dim_offset += elem.cell.topological_dimension() - finat_elem, ds = _create_element(elem, **kwargs) - elements.append(finat_elem) - deps.update(ds) - return finat.TensorProductElement(elements), deps - - -@convert.register(finat.ufl.HDivElement) -def convert_hdivelement(element, **kwargs): - finat_elem, deps = _create_element(element._element, **kwargs) - return finat.HDivElement(finat_elem), deps - - -@convert.register(finat.ufl.HCurlElement) -def convert_hcurlelement(element, **kwargs): - finat_elem, deps = _create_element(element._element, **kwargs) - return finat.HCurlElement(finat_elem), deps - - -@convert.register(finat.ufl.WithMapping) -def convert_withmapping(element, **kwargs): - return _create_element(element.wrapee, **kwargs) - - -@convert.register(finat.ufl.RestrictedElement) -def convert_restrictedelement(element, **kwargs): - finat_elem, deps = _create_element(element._element, **kwargs) - return finat.RestrictedElement(finat_elem, element.restriction_domain()), deps - - -hexahedron_tpc = ufl.TensorProductCell(ufl.interval, ufl.interval, ufl.interval) -quadrilateral_tpc = ufl.TensorProductCell(ufl.interval, ufl.interval) -_cache = weakref.WeakKeyDictionary() - - -def create_element(ufl_element, shape_innermost=True, shift_axes=0, restriction=None): - """Create a FInAT element (suitable for tabulating with) given a UFL element. - - :arg ufl_element: The UFL element to create a FInAT element from. - :arg shape_innermost: Vector/tensor indices come after basis function indices - :arg restriction: cell restriction in interior facet integrals - (only for runtime tabulated elements) - """ - finat_element, deps = _create_element(ufl_element, - shape_innermost=shape_innermost, - shift_axes=shift_axes, - restriction=restriction) - return finat_element - - -def _create_element(ufl_element, **kwargs): - """A caching wrapper around :py:func:`convert`. - - Takes a UFL element and an unspecified set of parameter options, - and returns the converted element with the set of keyword names - that were relevant for conversion. - """ - # Look up conversion in cache - try: - cache = _cache[ufl_element] - except KeyError: - _cache[ufl_element] = {} - cache = _cache[ufl_element] - - for key, finat_element in cache.items(): - # Cache hit if all relevant parameter values match. - if all(kwargs[param] == value for param, value in key): - return finat_element, set(param for param, value in key) - - # Convert if cache miss - if ufl_element.cell is None: - raise ValueError("Don't know how to build element when cell is not given") - - finat_element, deps = convert(ufl_element, **kwargs) - - # Store conversion in cache - key = frozenset((param, kwargs[param]) for param in deps) - cache[key] = finat_element - - # Forward result - return finat_element, deps - - -def create_base_element(ufl_element, **kwargs): - """Create a "scalar" base FInAT element given a UFL element. - Takes a UFL element and an unspecified set of parameter options, - and returns the converted element. - """ - finat_element = create_element(ufl_element, **kwargs) - if isinstance(finat_element, finat.TensorFiniteElement): - finat_element = finat_element.base_element - return finat_element diff --git a/tsfc/kernel_interface/common.py b/tsfc/kernel_interface/common.py index f757af951f..09bb881678 100644 --- a/tsfc/kernel_interface/common.py +++ b/tsfc/kernel_interface/common.py @@ -15,7 +15,7 @@ from gem.optimise import remove_componenttensors as prune from numpy import asarray from tsfc import fem, ufl_utils -from tsfc.finatinterface import as_fiat_cell, create_element +from finat.element_factory import as_fiat_cell, create_element from tsfc.kernel_interface import KernelInterface from tsfc.logging import logger from ufl.utils.sequences import max_degree @@ -394,7 +394,7 @@ def lower_integral_type(fiat_cell, integral_type): elif integral_type == 'exterior_facet_top': entity_ids = [1] else: - entity_ids = list(range(len(fiat_cell.get_topology()[integration_dim]))) + entity_ids = list(fiat_cell.get_topology()[integration_dim]) return integration_dim, entity_ids diff --git a/tsfc/kernel_interface/firedrake_loopy.py b/tsfc/kernel_interface/firedrake_loopy.py index 0969854cd3..6ace74a0e1 100644 --- a/tsfc/kernel_interface/firedrake_loopy.py +++ b/tsfc/kernel_interface/firedrake_loopy.py @@ -12,7 +12,7 @@ import loopy as lp from tsfc import kernel_args, fem -from tsfc.finatinterface import create_element +from finat.element_factory import create_element from tsfc.kernel_interface.common import KernelBuilderBase as _KernelBuilderBase, KernelBuilderMixin, get_index_names, check_requirements, prepare_coefficient, prepare_arguments, prepare_constant from tsfc.loopy import generate as generate_loopy From e65ec71c06b602e275c1b782ae702189956e3de5 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 19 Dec 2024 11:17:39 -0600 Subject: [PATCH 3/3] remove test_create_fia(n)t_element --- tests/tsfc/test_create_fiat_element.py | 150 ------------------------ tests/tsfc/test_create_finat_element.py | 138 ---------------------- 2 files changed, 288 deletions(-) delete mode 100644 tests/tsfc/test_create_fiat_element.py delete mode 100644 tests/tsfc/test_create_finat_element.py diff --git a/tests/tsfc/test_create_fiat_element.py b/tests/tsfc/test_create_fiat_element.py deleted file mode 100644 index f99053ab97..0000000000 --- a/tests/tsfc/test_create_fiat_element.py +++ /dev/null @@ -1,150 +0,0 @@ -import pytest - -import FIAT -from FIAT.discontinuous_lagrange import DiscontinuousLagrange as FIAT_DiscontinuousLagrange - -import ufl -import finat.ufl -from finat.element_factory import create_element as _create_element - - -supported_elements = { - # These all map directly to FIAT elements - "Brezzi-Douglas-Marini": FIAT.BrezziDouglasMarini, - "Brezzi-Douglas-Fortin-Marini": FIAT.BrezziDouglasFortinMarini, - "Lagrange": FIAT.Lagrange, - "Nedelec 1st kind H(curl)": FIAT.Nedelec, - "Nedelec 2nd kind H(curl)": FIAT.NedelecSecondKind, - "Raviart-Thomas": FIAT.RaviartThomas, - "Regge": FIAT.Regge, -} -"""A :class:`.dict` mapping UFL element family names to their -FIAT-equivalent constructors.""" - - -def create_element(ufl_element): - """Create a FIAT element given a UFL element.""" - finat_element = _create_element(ufl_element) - return finat_element.fiat_equivalent - - -@pytest.fixture(params=["BDM", - "BDFM", - "Lagrange", - "N1curl", - "N2curl", - "RT", - "Regge"]) -def triangle_names(request): - return request.param - - -@pytest.fixture -def ufl_element(triangle_names): - return finat.ufl.FiniteElement(triangle_names, ufl.triangle, 2) - - -def test_triangle_basic(ufl_element): - element = create_element(ufl_element) - assert isinstance(element, supported_elements[ufl_element.family()]) - - -@pytest.fixture(params=["CG", "DG", "DG L2"], scope="module") -def tensor_name(request): - return request.param - - -@pytest.fixture(params=[ufl.interval, ufl.triangle, - ufl.quadrilateral], - ids=lambda x: x.cellname(), - scope="module") -def ufl_A(request, tensor_name): - return finat.ufl.FiniteElement(tensor_name, request.param, 1) - - -@pytest.fixture -def ufl_B(tensor_name): - return finat.ufl.FiniteElement(tensor_name, ufl.interval, 1) - - -def test_tensor_prod_simple(ufl_A, ufl_B): - tensor_ufl = finat.ufl.TensorProductElement(ufl_A, ufl_B) - - tensor = create_element(tensor_ufl) - A = create_element(ufl_A) - B = create_element(ufl_B) - - assert isinstance(tensor, FIAT.TensorProductElement) - - assert tensor.A is A - assert tensor.B is B - - -@pytest.mark.parametrize(('family', 'expected_cls'), - [('P', FIAT.GaussLobattoLegendre), - ('DP', FIAT.GaussLegendre), - ('DP L2', FIAT.GaussLegendre)]) -def test_interval_variant_default(family, expected_cls): - ufl_element = finat.ufl.FiniteElement(family, ufl.interval, 3) - assert isinstance(create_element(ufl_element), expected_cls) - - -@pytest.mark.parametrize(('family', 'variant', 'expected_cls'), - [('P', 'equispaced', FIAT.Lagrange), - ('P', 'spectral', FIAT.GaussLobattoLegendre), - ('DP', 'equispaced', FIAT_DiscontinuousLagrange), - ('DP', 'spectral', FIAT.GaussLegendre), - ('DP L2', 'equispaced', FIAT_DiscontinuousLagrange), - ('DP L2', 'spectral', FIAT.GaussLegendre)]) -def test_interval_variant(family, variant, expected_cls): - ufl_element = finat.ufl.FiniteElement(family, ufl.interval, 3, variant=variant) - assert isinstance(create_element(ufl_element), expected_cls) - - -def test_triangle_variant_spectral(): - ufl_element = finat.ufl.FiniteElement('DP', ufl.triangle, 2, variant='spectral') - create_element(ufl_element) - - -def test_triangle_variant_spectral_l2(): - ufl_element = finat.ufl.FiniteElement('DP L2', ufl.triangle, 2, variant='spectral') - create_element(ufl_element) - - -def test_quadrilateral_variant_spectral_q(): - element = create_element(finat.ufl.FiniteElement('Q', ufl.quadrilateral, 3, variant='spectral')) - assert isinstance(element.element.A, FIAT.GaussLobattoLegendre) - assert isinstance(element.element.B, FIAT.GaussLobattoLegendre) - - -def test_quadrilateral_variant_spectral_dq(): - element = create_element(finat.ufl.FiniteElement('DQ', ufl.quadrilateral, 1, variant='spectral')) - assert isinstance(element.element.A, FIAT.GaussLegendre) - assert isinstance(element.element.B, FIAT.GaussLegendre) - - -def test_quadrilateral_variant_spectral_dq_l2(): - element = create_element(finat.ufl.FiniteElement('DQ L2', ufl.quadrilateral, 1, variant='spectral')) - assert isinstance(element.element.A, FIAT.GaussLegendre) - assert isinstance(element.element.B, FIAT.GaussLegendre) - - -def test_quadrilateral_variant_spectral_rtcf(): - element = create_element(finat.ufl.FiniteElement('RTCF', ufl.quadrilateral, 2, variant='spectral')) - assert isinstance(element.element._elements[0].A, FIAT.GaussLobattoLegendre) - assert isinstance(element.element._elements[0].B, FIAT.GaussLegendre) - assert isinstance(element.element._elements[1].A, FIAT.GaussLegendre) - assert isinstance(element.element._elements[1].B, FIAT.GaussLobattoLegendre) - - -def test_cache_hit(ufl_element): - A = create_element(ufl_element) - B = create_element(ufl_element) - - assert A is B - - -if __name__ == "__main__": - import os - import sys - pytest.main(args=[os.path.abspath(__file__)] + sys.argv[1:]) diff --git a/tests/tsfc/test_create_finat_element.py b/tests/tsfc/test_create_finat_element.py deleted file mode 100644 index 7964824c2c..0000000000 --- a/tests/tsfc/test_create_finat_element.py +++ /dev/null @@ -1,138 +0,0 @@ -import pytest - -import ufl -import finat.ufl -import finat -from finat.element_factory import create_element, supported_elements - - -@pytest.fixture(params=["BDM", - "BDFM", - "Lagrange", - "N1curl", - "N2curl", - "RT", - "Regge"]) -def triangle_names(request): - return request.param - - -@pytest.fixture -def ufl_element(triangle_names): - return finat.ufl.FiniteElement(triangle_names, ufl.triangle, 2) - - -def test_triangle_basic(ufl_element): - element = create_element(ufl_element) - assert isinstance(element, supported_elements[ufl_element.family()]) - - -@pytest.fixture -def ufl_vector_element(triangle_names): - return finat.ufl.VectorElement(triangle_names, ufl.triangle, 2) - - -def test_triangle_vector(ufl_element, ufl_vector_element): - scalar = create_element(ufl_element) - vector = create_element(ufl_vector_element) - - assert isinstance(vector, finat.TensorFiniteElement) - assert scalar == vector.base_element - - -@pytest.fixture(params=["CG", "DG", "DG L2"]) -def tensor_name(request): - return request.param - - -@pytest.fixture(params=[ufl.interval, ufl.triangle, - ufl.quadrilateral], - ids=lambda x: x.cellname()) -def ufl_A(request, tensor_name): - return finat.ufl.FiniteElement(tensor_name, request.param, 1) - - -@pytest.fixture -def ufl_B(tensor_name): - return finat.ufl.FiniteElement(tensor_name, ufl.interval, 1) - - -def test_tensor_prod_simple(ufl_A, ufl_B): - tensor_ufl = finat.ufl.TensorProductElement(ufl_A, ufl_B) - - tensor = create_element(tensor_ufl) - A = create_element(ufl_A) - B = create_element(ufl_B) - - assert isinstance(tensor, finat.TensorProductElement) - - assert tensor.factors == (A, B) - - -@pytest.mark.parametrize(('family', 'expected_cls'), - [('P', finat.GaussLobattoLegendre), - ('DP', finat.GaussLegendre), - ('DP L2', finat.GaussLegendre)]) -def test_interval_variant_default(family, expected_cls): - ufl_element = finat.ufl.FiniteElement(family, ufl.interval, 3) - assert isinstance(create_element(ufl_element), expected_cls) - - -@pytest.mark.parametrize(('family', 'variant', 'expected_cls'), - [('P', 'equispaced', finat.Lagrange), - ('P', 'spectral', finat.GaussLobattoLegendre), - ('DP', 'equispaced', finat.DiscontinuousLagrange), - ('DP', 'spectral', finat.GaussLegendre), - ('DP L2', 'equispaced', finat.DiscontinuousLagrange), - ('DP L2', 'spectral', finat.GaussLegendre)]) -def test_interval_variant(family, variant, expected_cls): - ufl_element = finat.ufl.FiniteElement(family, ufl.interval, 3, variant=variant) - assert isinstance(create_element(ufl_element), expected_cls) - - -def test_triangle_variant_spectral(): - ufl_element = finat.ufl.FiniteElement('DP', ufl.triangle, 2, variant='spectral') - create_element(ufl_element) - - -def test_triangle_variant_spectral_l2(): - ufl_element = finat.ufl.FiniteElement('DP L2', ufl.triangle, 2, variant='spectral') - create_element(ufl_element) - - -def test_quadrilateral_variant_spectral_q(): - element = create_element(finat.ufl.FiniteElement('Q', ufl.quadrilateral, 3, variant='spectral')) - assert isinstance(element.product.factors[0], finat.GaussLobattoLegendre) - assert isinstance(element.product.factors[1], finat.GaussLobattoLegendre) - - -def test_quadrilateral_variant_spectral_dq(): - element = create_element(finat.ufl.FiniteElement('DQ', ufl.quadrilateral, 1, variant='spectral')) - assert isinstance(element.product.factors[0], finat.GaussLegendre) - assert isinstance(element.product.factors[1], finat.GaussLegendre) - - -def test_quadrilateral_variant_spectral_dq_l2(): - element = create_element(finat.ufl.FiniteElement('DQ L2', ufl.quadrilateral, 1, variant='spectral')) - assert isinstance(element.product.factors[0], finat.GaussLegendre) - assert isinstance(element.product.factors[1], finat.GaussLegendre) - - -def test_cache_hit(ufl_element): - A = create_element(ufl_element) - B = create_element(ufl_element) - - assert A is B - - -def test_cache_hit_vector(ufl_vector_element): - A = create_element(ufl_vector_element) - B = create_element(ufl_vector_element) - - assert A is B - - -if __name__ == "__main__": - import os - import sys - pytest.main(args=[os.path.abspath(__file__)] + sys.argv[1:])