Skip to content

Commit

Permalink
fix formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
YigitElma committed Mar 5, 2024
1 parent 33c1029 commit 3bd6b98
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
3 changes: 2 additions & 1 deletion zernipy/grid.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Classes for representing flux coordinates."""

from abc import ABC, abstractmethod
from zernipy.backend import np, jnp, put

from zernipy.backend import jnp, np, put


class _Indexable:
Expand Down
12 changes: 7 additions & 5 deletions zernipy/zernike.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Functions for evaluating Zernike polynomials and their derivatives."""

import functools
from zernipy.backend import cond, custom_jvp, fori_loop, gammaln, jit, jnp, select, switch

from zernipy.backend import cond, custom_jvp, fori_loop, gammaln, jit, jnp, switch


def jacobi_poly_single(x, n, alpha, beta=0, P_n1=0, P_n2=0):
Expand Down Expand Up @@ -1197,7 +1198,7 @@ def find_init_jacobi(dx, args):


@custom_jvp
@functools.partial(jit, static_argnums=[3,4])
@functools.partial(jit, static_argnums=[3, 4])
def zernike_radial_jvp_gpu(r, l, m, dr=0, repeat=1):
"""Radial part of zernike polynomials.
Expand Down Expand Up @@ -1229,12 +1230,13 @@ def zernike_radial_jvp_gpu(r, l, m, dr=0, repeat=1):
"Analytic radial derivatives of Zernike polynomials for order>4 "
+ "have not been implemented."
)

def update(x, args):
index, result, out = args
idx = index[x]
out = out.at[:, idx].set(result[:, None])
return (index, result, out)

def body_inner(N, args):
alpha, out, P_past = args
P_n2 = P_past[0]
Expand Down Expand Up @@ -1299,9 +1301,9 @@ def find_inter_jacobi(dx, args):
- coef[3] * 128 * (2 * alpha + 3) * r ** (alpha + 2) * P_n[3]
+ coef[4] * 256 * r ** (alpha + 4) * P_n[4]
)
index = jnp.argwhere(jnp.logical_and(m == alpha, n == N), size=repeat)
index = jnp.argwhere(jnp.logical_and(m == alpha, n == N), size=repeat)
_, _, out = fori_loop(0, repeat, update, (index, result, out))

# Shift past values if needed
mask = N >= 2 + dxs
P_n2 = jnp.where(mask[:, None], P_n1, P_n2)
Expand Down

0 comments on commit 3bd6b98

Please sign in to comment.