Skip to content

Commit

Permalink
remove math.prod, test up to degree 7 in 2D/3D
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Oct 28, 2023
1 parent 1ff2114 commit b3a488d
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 8 deletions.
15 changes: 9 additions & 6 deletions FIAT/expansions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import sympy
from FIAT import reference_element
from FIAT import jacobi
from math import prod

from FIAT.reference_element import UFCInterval
from FIAT.quadrature import GaussLegendreQuadratureLineRule
Expand Down Expand Up @@ -56,7 +55,6 @@ def dubiner_deriv_1d(order, dim, x):
derivs += results * (-0.5*j)
if j > 1:
derivs *= xhat ** (j - 1)

indices = [flat_index(i, j) for i in range(n + 1)]
dphi[indices, :] = derivs
return dphi
Expand All @@ -65,11 +63,16 @@ def dubiner_deriv_1d(order, dim, x):
def duffy_chain_rule(A, eta, tabulations):
dphi_dxi = [tabulations[alpha] for alpha in sorted(tabulations, reverse=True) if sum(alpha) == 1]
dim = len(eta)
eta1 = [(1. - x) * 0.5 for x in eta]
for i in range(dim):
for j in range(i):
dphi_dxi[i] += dphi_dxi[j] * (1. + eta[j])*0.5 * prod((1. - eta[k])*0.5 for k in range(j+1, dim) if k != i)
dphi_dxi[i] /= prod((1. - x)*0.5 for x in eta[i+1:])

Jij = -0.5 * (1. + eta[j])
for k in range(j + 1, dim):
if k != i:
Jij *= eta1[k]
dphi_dxi[i] -= dphi_dxi[j] * Jij
for j in range(i + 1, dim):
dphi_dxi[i] /= eta1[j]
k = 0
dphi_dx = [sum(dphi_dxi[j] * A[j][i] for j in range(dim)) for i in range(dim)]
for alpha in sorted(tabulations, reverse=True):
Expand Down Expand Up @@ -297,7 +300,7 @@ def _tabulate_duffy(self, n, pts):
for alpha in tabulations:
results = tabulations[alpha]
for k in range(n+1):
results[k, :] *= (k + 0.5)**0.5
results[k] *= (k + 0.5)**0.5
return tabulations

def tabulate_derivatives(self, n, pts):
Expand Down
3 changes: 2 additions & 1 deletion test/unit/test_gauss_legendre.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def symmetric_simplex(dim):
return s


@pytest.mark.parametrize("dim, degree", sum(([(d, p) for p in range(0, 8-d)] for d in range(1, 4)), []))
@pytest.mark.parametrize("degree", range(0, 8))
@pytest.mark.parametrize("dim", (1, 2, 3))
def test_gl_basis_values(dim, degree):
"""Ensure that integrating a simple monomial produces the expected results."""
from FIAT import GaussLegendre, make_quadrature
Expand Down
3 changes: 2 additions & 1 deletion test/unit/test_gauss_lobatto_legendre.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def symmetric_simplex(dim):
return s


@pytest.mark.parametrize("dim, degree", sum(([(d, p) for p in range(1, 8-d)] for d in range(1, 4)), []))
@pytest.mark.parametrize("degree", range(1, 8))
@pytest.mark.parametrize("dim", (1, 2, 3))
def test_gll_basis_values(dim, degree):
"""Ensure that integrating a simple monomial produces the expected results."""
from FIAT import GaussLobattoLegendre, make_quadrature
Expand Down

0 comments on commit b3a488d

Please sign in to comment.