Skip to content

Commit

Permalink
Common Duffy tabulation
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Oct 29, 2023
1 parent daaab1e commit 9064496
Showing 1 changed file with 30 additions and 54 deletions.
84 changes: 30 additions & 54 deletions FIAT/expansions.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def __new__(cls, ref_el, *args, **kwargs):
elif ref_el.get_shape() == reference_element.TETRAHEDRON:
return TetrahedronExpansionSet(ref_el)
else:
raise Exception("Unknown reference element type.")
raise ValueError("Invalid reference element type.")

def __init__(self, ref_el):
self.ref_el = ref_el
Expand All @@ -244,7 +244,35 @@ def __init__(self, ref_el):
self._dmats_cache = {}

def _tabulate_duffy(self, n, pts):
raise NotImplementedError()
from FIAT.polynomial_set import mis
dim = self.ref_el.get_spatial_dimension()
sd = self.get_num_members(n)
xi = numpy.transpose(numpy.dot(pts, self.A.T) + self.b)
eta = (lambda x: x, lambda x: x, eta_square, eta_cube)[dim](xi)
basis = [dubiner_1d(n, k, eta[k]) for k in range(dim)]
derivs = [dubiner_deriv_1d(n, k, eta[k]) for k in range(dim)]
alphas = mis(dim, 0) + mis(dim, 1)
tabulations = {}
for alpha in alphas:
V = [v if a == 0 else dv for a, v, dv in zip(alpha, basis, derivs)]
phi = V[0]
if dim >= 2:
phi1 = phi
phi = numpy.copy(V[1])
for i in range(n + 1):
indices = [morton_index2(i, j) for j in range(n + 1 - i)]
phi[indices] *= phi1[i]
if dim >= 3:
phi2 = phi
phi = numpy.zeros((sd, V[0].shape[1]), dtype=V[0].dtype)
for i in range(n + 1):
for j in range(n + 1 - i):
Vij = phi2[morton_index2(i, j)]
for k in range(n + 1 - i - j):
phi[morton_index3(i, j, k)] = V[2][morton_index2(i + j, k)] * Vij
tabulations[alpha] = phi
duffy_chain_rule(self.A, eta, tabulations)
return tabulations

def make_dmats(self, degree):
cache = self._dmats_cache
Expand Down Expand Up @@ -313,12 +341,6 @@ def tabulate(self, n, pts):
else:
return []

def _tabulate_duffy(self, n, pts):
xi = numpy.dot(pts, self.A.T) + self.b
tabulations = {(0,): dubiner_1d(n, 0, xi),
(1,): dubiner_deriv_1d(n, 0, xi) * self.A[0][0]}
return tabulations

def tabulate_derivatives(self, n, pts):
"""Returns a tuple of length one (A,) such that
A[i,j] = D phi_i(pts[j]). The tuple is returned for
Expand Down Expand Up @@ -414,28 +436,6 @@ def _tabulate(self, n, pts):
return results
# return self.scale * results

def _tabulate_duffy(self, n, pts):
from FIAT.polynomial_set import mis
idx = morton_index2
sd = self.get_num_members(n)
xi = numpy.transpose(numpy.dot(pts, self.A.T) + self.b)
eta = eta_square(xi)
dim = len(eta)
basis = [dubiner_1d(n, k, eta[k]) for k in range(dim)]
derivs = [dubiner_deriv_1d(n, k, eta[k]) for k in range(dim)]
alphas = mis(dim, 0) + mis(dim, 1)
tabulations = {}
for alpha in alphas:
V = [v if a == 0 else dv for a, v, dv in zip(alpha, basis, derivs)]
phi = numpy.zeros((sd, V[0].shape[1]), dtype=V[0].dtype)
for i in range(n + 1):
Vi = V[0][i]
for j in range(n + 1 - i):
phi[idx(i, j)] = V[1][morton_index2(i, j)] * Vi
tabulations[alpha] = phi
duffy_chain_rule(self.A, eta, tabulations)
return tabulations

def tabulate_derivatives(self, n, pts):
order = 1
data = _tabulate_dpts(self._tabulate, 2, n, order, numpy.array(pts))
Expand Down Expand Up @@ -539,30 +539,6 @@ def _tabulate(self, n, pts):

return results

def _tabulate_duffy(self, n, pts):
from FIAT.polynomial_set import mis
idx = morton_index3
sd = self.get_num_members(n)
xi = numpy.transpose(numpy.dot(pts, self.A.T) + self.b)
eta = eta_cube(xi)
dim = len(eta)
basis = [dubiner_1d(n, k, eta[k]) for k in range(dim)]
derivs = [dubiner_deriv_1d(n, k, eta[k]) for k in range(dim)]
alphas = mis(dim, 0) + mis(dim, 1)
tabulations = {}
for alpha in alphas:
V = [v if a == 0 else dv for a, v, dv in zip(alpha, basis, derivs)]
phi = numpy.zeros((sd, V[0].shape[1]), dtype=V[0].dtype)
for i in range(n + 1):
Vi = V[0][i]
for j in range(n + 1 - i):
Vij = V[1][morton_index2(i, j)] * Vi
for k in range(n + 1 - i - j):
phi[idx(i, j, k)] = V[2][morton_index2(i + j, k)] * Vij
tabulations[alpha] = phi
duffy_chain_rule(self.A, eta, tabulations)
return tabulations

def tabulate_derivatives(self, n, pts):
order = 1
D = 3
Expand Down

0 comments on commit 9064496

Please sign in to comment.