Skip to content

Commit

Permalink
tsfc: introduce MixedMesh
Browse files Browse the repository at this point in the history
  • Loading branch information
ksagiyam committed Dec 11, 2024
1 parent 571eeb0 commit 5fecff4
Show file tree
Hide file tree
Showing 9 changed files with 477 additions and 296 deletions.
5 changes: 3 additions & 2 deletions tests/tsfc/test_tsfc_182.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from ufl import Coefficient, TestFunction, dx, inner, tetrahedron, Mesh, FunctionSpace
from ufl import Coefficient, TestFunction, dx, inner, tetrahedron, Mesh, MixedMesh, FunctionSpace
from finat.ufl import FiniteElement, MixedElement, VectorElement

from tsfc import compile_form
Expand All @@ -20,7 +20,8 @@ def test_delta_elimination(mode):

element_chi_lambda = MixedElement(element_eps_p, element_lambda)
domain = Mesh(VectorElement("Lagrange", tetrahedron, 1))
space = FunctionSpace(domain, element_chi_lambda)
domains = MixedMesh(domain, domain)
space = FunctionSpace(domains, element_chi_lambda)

chi_lambda = Coefficient(space)
delta_chi_lambda = TestFunction(space)
Expand Down
5 changes: 3 additions & 2 deletions tests/tsfc/test_tsfc_204.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from tsfc import compile_form
from ufl import (Coefficient, FacetNormal,
FunctionSpace, Mesh, as_matrix,
FunctionSpace, Mesh, MixedMesh, as_matrix,
dot, dS, ds, dx, facet, grad, inner, outer, split, triangle)
from finat.ufl import BrokenElement, FiniteElement, MixedElement, VectorElement


def test_physically_mapped_facet():
mesh = Mesh(VectorElement("P", triangle, 1))
meshes = MixedMesh(mesh, mesh, mesh, mesh, mesh)

# set up variational problem
U = FiniteElement("Morley", mesh.ufl_cell(), 2)
Expand All @@ -15,7 +16,7 @@ def test_physically_mapped_facet():
Vv = VectorElement(BrokenElement(V))
Qhat = VectorElement(BrokenElement(V[facet]), dim=2)
Vhat = VectorElement(V[facet], dim=2)
Z = FunctionSpace(mesh, MixedElement(U, Vv, Qhat, Vhat, R))
Z = FunctionSpace(meshes, MixedElement(U, Vv, Qhat, Vhat, R))

z = Coefficient(Z)
u, d, qhat, dhat, lam = split(z)
Expand Down
73 changes: 42 additions & 31 deletions tsfc/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ufl.algorithms import extract_arguments, extract_coefficients
from ufl.algorithms.analysis import has_type
from ufl.classes import Form, GeometricQuantity
from ufl.domain import extract_unique_domain
from ufl.domain import extract_unique_domain, extract_domains

import gem
import gem.impero_utils as impero_utils
Expand All @@ -26,9 +26,9 @@


TSFCIntegralDataInfo = collections.namedtuple("TSFCIntegralDataInfo",
["domain", "integral_type", "subdomain_id", "domain_number",
["domain", "integral_type", "subdomain_id", "domain_number", "domain_integral_type_map",
"arguments",
"coefficients", "coefficient_numbers"])
"coefficients", "coefficient_split", "coefficient_numbers"])
TSFCIntegralDataInfo.__doc__ = """
Minimal set of objects for kernel builders.
Expand All @@ -47,7 +47,7 @@
"""


def compile_form(form, prefix="form", parameters=None, interface=None, diagonal=False, log=False):
def compile_form(form, prefix="form", parameters=None, dont_split_numbers=(), diagonal=False, log=False):
"""Compiles a UFL form into a set of assembly kernels.
:arg form: UFL form
Expand All @@ -65,76 +65,85 @@ def compile_form(form, prefix="form", parameters=None, interface=None, diagonal=

# Determine whether in complex mode:
complex_mode = parameters and is_complex(parameters.get("scalar_type"))
fd = ufl_utils.compute_form_data(form, complex_mode=complex_mode)
form_data = ufl_utils.compute_form_data(form,
do_split_coefficients=tuple(c for i, c in enumerate(form.coefficients()) if i not in dont_split_numbers),
complex_mode=complex_mode)
logger.info(GREEN % "compute_form_data finished in %g seconds.", time.time() - cpu_time)

kernels = []
for integral_data in fd.integral_data:
for integral_data in form_data.integral_data:
start = time.time()
kernel = compile_integral(integral_data, fd, prefix, parameters, interface=interface, diagonal=diagonal, log=log)
if kernel is not None:
kernels.append(kernel)
if integral_data.integrals:
kernel = compile_integral(integral_data, form_data, prefix, parameters, diagonal=diagonal, log=log)
if kernel is not None:
kernels.append(kernel)
logger.info(GREEN % "compile_integral finished in %g seconds.", time.time() - start)

logger.info(GREEN % "TSFC finished in %g seconds.", time.time() - cpu_time)
return kernels


def compile_integral(integral_data, form_data, prefix, parameters, interface, *, diagonal=False, log=False):
def compile_integral(integral_data, form_data, prefix, parameters, *, diagonal=False, log=False):
"""Compiles a UFL integral into an assembly kernel.
:arg integral_data: UFL integral data
:arg form_data: UFL form data
:arg prefix: kernel name will start with this string
:arg parameters: parameters object
:arg interface: backend module for the kernel interface
:arg diagonal: Are we building a kernel for the diagonal of a rank-2 element tensor?
:arg log: bool if the Kernel should be profiled with Log events
:returns: a kernel constructed by the kernel interface
"""
parameters = preprocess_parameters(parameters)
if interface is None:
interface = firedrake_interface_loopy.KernelBuilder
scalar_type = parameters["scalar_type"]
integral_type = integral_data.integral_type
if integral_type.startswith("interior_facet") and diagonal:
raise NotImplementedError("Sorry, we can't assemble the diagonal of a form for interior facet integrals")
mesh = integral_data.domain
arguments = form_data.preprocessed_form.arguments()
kernel_name = f"{prefix}_{integral_type}_integral"
# Dict mapping domains to index in original_form.ufl_domains()
domain_numbering = form_data.original_form.domain_numbering()
domain_number = domain_numbering[integral_data.domain]
coefficients = [form_data.function_replace_map[c] for c in integral_data.integral_coefficients]
# This is which coefficient in the original form the
# current coefficient is.
# Consider f*v*dx + g*v*ds, the full form contains two
# coefficients, but each integral only requires one.
coefficient_numbers = tuple(form_data.original_coefficient_positions[i]
for i, (_, enabled) in enumerate(zip(form_data.reduced_coefficients, integral_data.enabled_coefficients))
if enabled)
coefficients = []
coefficient_split = {}
coefficient_numbers = []
for i, (coeff_orig, enabled) in enumerate(zip(form_data.reduced_coefficients, integral_data.enabled_coefficients)):
if enabled:
coeff = form_data.function_replace_map[coeff_orig]
coefficients.append(coeff)
if coeff in form_data.coefficient_split:
coefficient_split[coeff] = form_data.coefficient_split[coeff]
coefficient_numbers.append(form_data.original_coefficient_positions[i])
mesh = integral_data.domain
all_meshes = extract_domains(form_data.original_form)
domain_number = all_meshes.index(mesh)
integral_data_info = TSFCIntegralDataInfo(domain=integral_data.domain,
integral_type=integral_data.integral_type,
subdomain_id=integral_data.subdomain_id,
domain_number=domain_number,
domain_integral_type_map={mesh: integral_data.domain_integral_type_map[mesh] if mesh in integral_data.domain_integral_type_map else None for mesh in all_meshes},
arguments=arguments,
coefficients=coefficients,
coefficient_split=coefficient_split,
coefficient_numbers=coefficient_numbers)
builder = interface(integral_data_info,
scalar_type,
diagonal=diagonal)
builder.set_coordinates(mesh)
builder.set_cell_sizes(mesh)
builder.set_coefficients(integral_data, form_data)
builder = firedrake_interface_loopy.KernelBuilder(integral_data_info,
scalar_type,
diagonal=diagonal)
builder.set_entity_numbers(all_meshes)
builder.set_entity_orientations(all_meshes)
builder.set_coordinates(all_meshes)
builder.set_cell_orientations(all_meshes)
builder.set_cell_sizes(all_meshes)
builder.set_coefficients()
# TODO: We do not want pass constants to kernels that do not need them
# so we should attach the constants to integral data instead
builder.set_constants(form_data.constants)
ctx = builder.create_context()
for integral in integral_data.integrals:
params = parameters.copy()
params.update(integral.metadata()) # integral metadata overrides
integrand = ufl.replace(integral.integrand(), form_data.function_replace_map)
integrand_exprs = builder.compile_integrand(integrand, params, ctx)
integrand_exprs = builder.compile_integrand(integral.integrand(), params, ctx)
integral_exprs = builder.construct_integrals(integrand_exprs, params)
builder.stash_integrals(integral_exprs, params, ctx)
return builder.construct_kernel(kernel_name, ctx, log)
Expand Down Expand Up @@ -207,6 +216,7 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *,
if domain is None:
domain = extract_unique_domain(expression)
assert domain is not None
builder._domain_integral_type_map = {domain: "cell"}

# Collect required coefficients and determine numbering
coefficients = extract_coefficients(expression)
Expand All @@ -219,7 +229,8 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *,
# Create a fake coordinate coefficient for a domain.
coords_coefficient = ufl.Coefficient(ufl.FunctionSpace(domain, domain.ufl_coordinate_element()))
builder.domain_coordinate[domain] = coords_coefficient
builder.set_cell_sizes(domain)
builder.set_cell_orientations((domain, ))
builder.set_cell_sizes((domain, ))
coefficients = [coords_coefficient] + coefficients
needs_external_coords = True
builder.set_coefficients(coefficients)
Expand All @@ -235,7 +246,7 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *,
ufl_cell=domain.ufl_cell(),
# FIXME: change if we ever implement
# interpolation on facets.
integral_type="cell",
domain_integral_type_map={domain: "cell"},
argument_multiindices=argument_multiindices,
index_cache={},
scalar_type=parameters["scalar_type"])
Expand Down
Loading

0 comments on commit 5fecff4

Please sign in to comment.