Skip to content

Commit

Permalink
Merge pull request #54 from firedrakeproject/wence/fix/serendipity-de…
Browse files Browse the repository at this point in the history
…rivs

Broadcast result of lambdify in Serendipity tabulation
  • Loading branch information
wence- authored Jul 22, 2020
2 parents d9aac42 + 8cac8ab commit e16d9f3
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 2 deletions.
7 changes: 5 additions & 2 deletions FIAT/serendipity.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ def tabulate(self, order, points, entity=None):
raise NotImplementedError('no tabulate method for serendipity elements of dimension 1 or less.')
if dim >= 4:
raise NotImplementedError('tabulate does not support higher dimensions than 3.')
points = np.asarray(points)
npoints, pointdim = points.shape
for o in range(order + 1):
alphas = mis(dim, o)
for alpha in alphas:
Expand All @@ -160,8 +162,9 @@ def tabulate(self, order, points, entity=None):
callable = lambdify(variables[:dim], polynomials, modules="numpy", dummify=True)
self.basis[alpha] = polynomials
self.basis_callable[alpha] = callable
points = np.asarray(points)
T = np.asarray(callable(*(points[:, i] for i in range(points.shape[1]))))
tabulation = callable(*(points[:, i] for i in range(pointdim)))
T = np.asarray([np.broadcast_to(tab, (npoints, ))
for tab in tabulation])
phivals[alpha] = T
return phivals

Expand Down
32 changes: 32 additions & 0 deletions test/unit/test_serendipity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from FIAT.reference_element import UFCQuadrilateral
from FIAT import Serendipity
import numpy as np
import sympy


def test_serendipity_derivatives():
cell = UFCQuadrilateral()
S = Serendipity(cell, 2)

x = sympy.DeferredVector("X")
X, Y = x[0], x[1]
basis_functions = [
(1 - X)*(1 - Y),
Y*(1 - X),
X*(1 - Y),
X*Y,
Y*(1 - X)*(Y - 1),
X*Y*(Y - 1),
X*(1 - Y)*(X - 1),
X*Y*(X - 1),
]
points = [[0.5, 0.5], [0.25, 0.75]]
for alpha, actual in S.tabulate(2, points).items():
expect = list(sympy.diff(basis, *zip([X, Y], alpha))
for basis in basis_functions)
expect = list([basis.subs(dict(zip([X, Y], point)))
for point in points]
for basis in expect)
assert actual.shape == (8, 2)
assert np.allclose(np.asarray(expect, dtype=float),
actual.reshape(8, 2))

0 comments on commit e16d9f3

Please sign in to comment.