Skip to content

Commit

Permalink
update with desc code
Browse files Browse the repository at this point in the history
  • Loading branch information
YigitElma committed May 27, 2024
1 parent 51fc1e9 commit ec66d46
Showing 1 changed file with 16 additions and 17 deletions.
33 changes: 16 additions & 17 deletions zernipax/zernike.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Functions for evaluating Zernike polynomials and their derivatives."""

import functools
from math import factorial

import mpmath

Expand Down Expand Up @@ -2056,15 +2057,13 @@ def polyder_vec(p, m, exact=False):


def _polyder_exact(p, m):
from scipy.special import factorial

m = np.asarray(m, dtype=int) # order of derivative
p = np.atleast_2d(p)
order = p.shape[1] - 1

D = np.arange(order, -1, -1)
num = np.array([factorial(i, exact=True) for i in D], dtype=object)
den = np.array([factorial(max(i - m, 0), exact=True) for i in D], dtype=object)
num = np.array([factorial(i) for i in D], dtype=object)
den = np.array([factorial(max(i - m, 0)) for i in D], dtype=object)
D = (num // den).astype(p.dtype)

p = np.roll(D * p, m, axis=1)
Expand Down Expand Up @@ -2182,6 +2181,8 @@ def zernike_radial_coeffs(l, m, exact=True):
Integer representation is exact up to l~54, so leaving `exact` arg as False
can speed up evaluation with no loss in accuracy
"""
from decimal import Decimal, getcontext

l = np.atleast_1d(l).astype(int)
m = np.atleast_1d(np.abs(m)).astype(int)
lm = np.vstack([l, m]).T
Expand All @@ -2190,13 +2191,11 @@ def zernike_radial_coeffs(l, m, exact=True):
lms, idx = np.unique(lm, return_inverse=True, axis=0)

if exact:
from scipy.special import factorial

_factorial = lambda x: factorial(x, exact=True)
# Increase the precision of Decimal operations
getcontext().prec = 100
else:
from math import factorial

_factorial = factorial
# Use lower precision for not exact calculations
getcontext().prec = 15
npoly = len(lms)
lmax = np.max(lms[:, 0])
coeffs = np.zeros((npoly, lmax + 1), dtype=object)
Expand All @@ -2205,13 +2204,13 @@ def zernike_radial_coeffs(l, m, exact=True):
ll = lms[ii, 0]
mm = lms[ii, 1]
for s in range(mm, ll + 1, 2):
coeffs[ii, s] = (
(-1) ** ((ll - s) // 2)
* _factorial((ll + s) // 2)
/ (
_factorial((ll - s) // 2)
* _factorial((s + mm) // 2)
* _factorial((s - mm) // 2)
coeffs[ii, s] = Decimal(
int((-1) ** ((ll - s) // 2) * factorial((ll + s) // 2))
) / Decimal(
int(
factorial((ll - s) // 2)
* factorial((s + mm) // 2)
* factorial((s - mm) // 2)
)
)
c = np.fliplr(np.where(lm_even, coeffs, 0))
Expand Down

0 comments on commit ec66d46

Please sign in to comment.