Skip to content

Commit

Permalink
Merge pull request firedrakeproject#115 from FInAT/pbrubeck/fix/point…
Browse files Browse the repository at this point in the history
…-evaluation

Tabulate Ciarlet generically
  • Loading branch information
dham authored Nov 22, 2023
2 parents cd26c09 + c4e3b58 commit a7080b1
Showing 1 changed file with 0 additions and 69 deletions.
69 changes: 0 additions & 69 deletions finat/fiat_elements.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import numpy as np
import sympy as sp
from functools import singledispatch

import FIAT
from FIAT.polynomial_set import mis, form_matrix_product

import gem
from gem.utils import cached_property
Expand Down Expand Up @@ -176,7 +174,6 @@ def point_evaluation(self, order, refcoords, entity=None):
esd = self.cell.construct_subelement(entity_dim).get_spatial_dimension()
assert isinstance(refcoords, gem.Node) and refcoords.shape == (esd,)

# Dispatch on FIAT element class
return point_evaluation(self._element, order, refcoords, (entity_dim, entity_i))

@cached_property
Expand Down Expand Up @@ -268,13 +265,7 @@ def mapping(self):
return result


@singledispatch
def point_evaluation(fiat_element, order, refcoords, entity):
raise AssertionError("FIAT element expected!")


@point_evaluation.register(FIAT.FiniteElement)
def point_evaluation_generic(fiat_element, order, refcoords, entity):
# Coordinates on the reference entity (SymPy)
esd, = refcoords.shape
Xi = sp.symbols('X Y Z')[:esd]
Expand Down Expand Up @@ -317,66 +308,6 @@ def point_evaluation_generic(fiat_element, order, refcoords, entity):
return result


@point_evaluation.register(FIAT.CiarletElement)
def point_evaluation_ciarlet(fiat_element, order, refcoords, entity):
# Coordinates on the reference entity (SymPy)
esd, = refcoords.shape
Xi = sp.symbols('X Y Z')[:esd]

# Coordinates on the reference cell
cell = fiat_element.get_reference_element()
X = cell.get_entity_transform(*entity)(Xi)

# Evaluate expansion set at SymPy point
poly_set = fiat_element.get_nodal_basis()
degree = poly_set.get_embedded_degree()
base_values = poly_set.get_expansion_set().tabulate(degree, [X])
m = len(base_values)
assert base_values.shape == (m, 1)
base_values_sympy = np.array(list(base_values.flat))

# Find constant polynomials
def is_const(expr):
try:
float(expr)
return True
except TypeError:
return False
const_mask = np.array(list(map(is_const, base_values_sympy)))

# Convert SymPy expression to GEM
mapper = gem.node.Memoizer(sympy2gem)
mapper.bindings = {s: gem.Indexed(refcoords, (i,))
for i, s in enumerate(Xi)}
base_values = gem.ListTensor(list(map(mapper, base_values.flat)))

# Populate result dict, creating precomputed coefficient
# matrices for each derivative tuple.
result = {}
for i in range(order + 1):
for alpha in mis(cell.get_spatial_dimension(), i):
D = form_matrix_product(poly_set.get_dmats(), alpha)
table = np.dot(poly_set.get_coeffs(), np.transpose(D))
assert table.shape[-1] == m
zerocols = np.isclose(abs(table).max(axis=tuple(range(table.ndim - 1))), 0.0)
if all(np.logical_or(const_mask, zerocols)):
# Casting is safe by assertion of is_const
vals = base_values_sympy[const_mask].astype(np.float64)
result[alpha] = gem.Literal(table[..., const_mask].dot(vals))
else:
beta = tuple(gem.Index() for s in table.shape[:-1])
k = gem.Index()
result[alpha] = gem.ComponentTensor(
gem.IndexSum(
gem.Product(gem.Indexed(gem.Literal(table), beta + (k,)),
gem.Indexed(base_values, (k,))),
(k,)
),
beta
)
return result


class Regge(FiatElement): # naturally tensor valued
def __init__(self, cell, degree):
super(Regge, self).__init__(FIAT.Regge(cell, degree))
Expand Down

0 comments on commit a7080b1

Please sign in to comment.