diff --git a/zernipy/grid.py b/zernipy/grid.py index 9b4131e..a6c2a57 100644 --- a/zernipy/grid.py +++ b/zernipy/grid.py @@ -1,7 +1,8 @@ """Classes for representing flux coordinates.""" from abc import ABC, abstractmethod -from zernipy.backend import np, jnp, put + +from zernipy.backend import jnp, np, put class _Indexable: diff --git a/zernipy/zernike.py b/zernipy/zernike.py index 9f69a32..9ee989d 100644 --- a/zernipy/zernike.py +++ b/zernipy/zernike.py @@ -1,7 +1,8 @@ """Functions for evaluating Zernike polynomials and their derivatives.""" import functools -from zernipy.backend import cond, custom_jvp, fori_loop, gammaln, jit, jnp, select, switch + +from zernipy.backend import cond, custom_jvp, fori_loop, gammaln, jit, jnp, switch def jacobi_poly_single(x, n, alpha, beta=0, P_n1=0, P_n2=0): @@ -1197,7 +1198,7 @@ def find_init_jacobi(dx, args): @custom_jvp -@functools.partial(jit, static_argnums=[3,4]) +@functools.partial(jit, static_argnums=[3, 4]) def zernike_radial_jvp_gpu(r, l, m, dr=0, repeat=1): """Radial part of zernike polynomials. @@ -1229,12 +1230,13 @@ def zernike_radial_jvp_gpu(r, l, m, dr=0, repeat=1): "Analytic radial derivatives of Zernike polynomials for order>4 " + "have not been implemented." ) + def update(x, args): index, result, out = args idx = index[x] out = out.at[:, idx].set(result[:, None]) return (index, result, out) - + def body_inner(N, args): alpha, out, P_past = args P_n2 = P_past[0] @@ -1299,9 +1301,9 @@ def find_inter_jacobi(dx, args): - coef[3] * 128 * (2 * alpha + 3) * r ** (alpha + 2) * P_n[3] + coef[4] * 256 * r ** (alpha + 4) * P_n[4] ) - index = jnp.argwhere(jnp.logical_and(m == alpha, n == N), size=repeat) + index = jnp.argwhere(jnp.logical_and(m == alpha, n == N), size=repeat) _, _, out = fori_loop(0, repeat, update, (index, result, out)) - + # Shift past values if needed mask = N >= 2 + dxs P_n2 = jnp.where(mask[:, None], P_n1, P_n2)