From 404e30db3eda301213a124a68edd99ebefc73f8c Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 1 Nov 2023 10:17:55 +0000 Subject: [PATCH] tidy up recurrence --- FIAT/expansions.py | 50 +++++++++++++++++++--------------------------- 1 file changed, 21 insertions(+), 29 deletions(-) diff --git a/FIAT/expansions.py b/FIAT/expansions.py index 0831fb6b1..7d17e8029 100644 --- a/FIAT/expansions.py +++ b/FIAT/expansions.py @@ -52,15 +52,15 @@ def recurrence(dim, n, factors, phi, dfactors=None, dphi=None): df5 = 2 * f4 * df4 # p = 1 - phi[idx(1)] = f1 + icur = idx(0) + inext = idx(1) + phi[inext] = f1 if not skip_derivs: - dphi[idx(1)] = df1 - + dphi[inext] = df1 # general p by recurrence for p in range(1, n): - icur = idx(p) + iprev, icur = icur, inext inext = idx(p + 1) - iprev = idx(p - 1) a = (2. * p + 1.) / (1. + p) b = p / (1. + p) phi[inext] = a * f1 * phi[icur] - b * f2 * phi[iprev] @@ -71,23 +71,19 @@ def recurrence(dim, n, factors, phi, dfactors=None, dphi=None): if dim < 2: return - # q = 1 for p in range(n): + # q = 1 icur = idx(p, 0) inext = idx(p, 1) g = (p + 1.5) * f3 - f4 phi[inext] = g * phi[icur] - if skip_derivs: - continue - dg = (p + 1.5) * df3 - df4 - dphi[inext] = g * dphi[icur] + phi[icur] * dg - - # general q by recurrence - for p in range(n - 1): + if not skip_derivs: + dg = (p + 1.5) * df3 - df4 + dphi[inext] = g * dphi[icur] + phi[icur] * dg + # general q by recurrence for q in range(1, n - p): - icur = idx(p, q) + iprev, icur = icur, inext inext = idx(p, q + 1) - iprev = idx(p, q - 1) aq, bq, cq = jrc(2 * p + 1, 0, q) g = aq * f3 + (bq - aq) * f4 h = cq * f5 @@ -103,32 +99,28 @@ def recurrence(dim, n, factors, phi, dfactors=None, dphi=None): z = 1 - 2 * f4 if dfactors: dz = -2 * df4 - # r = 1 + for p in range(n): - for q in range(n - p): + for q in range(0, n - p): + # r = 1 icur = idx(p, q, 0) inext = idx(p, q, 1) a = 2.0 + p + q b = 1.0 + p + q g = a * z + b phi[inext] = g * phi[icur] - if skip_derivs: - continue - dg = a * dz - dphi[inext] = g * dphi[icur] + phi[icur] * dg - - # general r by recurrence - for p in range(n - 1): - for q in range(0, n - p - 1): + if not skip_derivs: + dphi[inext] = g * dphi[icur] + a * phi[icur] * dz + # general r by recurrence for r in range(1, n - p - q): - icur = idx(p, q, r) + iprev, icur = icur, inext inext = idx(p, q, r + 1) - iprev = idx(p, q, r - 1) ar, br, cr = jrc(2 * p + 2 * q + 2, 0, r) - phi[inext] = (ar * z + br) * phi[icur] - cr * phi[iprev] + g = ar * z + br + phi[inext] = g * phi[icur] - cr * phi[iprev] if skip_derivs: continue - dphi[inext] = (ar * z + br) * dphi[icur] + ar * phi[icur] * dz - cr * dphi[iprev] + dphi[inext] = g * dphi[icur] + ar * phi[icur] * dz - cr * dphi[iprev] def _tabulate_dpts(tabulator, D, n, order, pts):