From 33c1029c9897be3b7554109ed07569cbfe993be1 Mon Sep 17 00:00:00 2001 From: YigitElma Date: Mon, 4 Mar 2024 20:38:02 -0500 Subject: [PATCH] remove extra functions --- notebooks/benchmarks.ipynb | 2 +- zernipy/zernike.py | 313 ------------------------------------- 2 files changed, 1 insertion(+), 314 deletions(-) diff --git a/notebooks/benchmarks.ipynb b/notebooks/benchmarks.ipynb index 2a773eb..7f0bbb5 100644 --- a/notebooks/benchmarks.ipynb +++ b/notebooks/benchmarks.ipynb @@ -1 +1 @@ -{"cells":[{"cell_type":"markdown","metadata":{},"source":["# Comparison of different available functions"]},{"cell_type":"code","execution_count":1,"metadata":{},"outputs":[],"source":["import sys\n","import os\n","\n","sys.path.insert(0, os.path.abspath(\".\"))\n","sys.path.append(os.path.abspath(\"../\"))"]},{"cell_type":"code","execution_count":2,"metadata":{},"outputs":[],"source":["# from zernipy import set_device\n","# set_device(\"gpu\")"]},{"cell_type":"code","execution_count":3,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["using JAX backend, jax version=0.4.14, jaxlib version=0.4.14, dtype=float64\n","Using device: CPU, with 20.36 GB available memory\n"]}],"source":["import numpy as np\n","from zernipy.zernike import *\n","from zernipy.basis import ZernikePolynomial, FourierZernikeBasis"]},{"cell_type":"code","execution_count":4,"metadata":{},"outputs":[],"source":["basis = ZernikePolynomial(L=50, M=50, spectral_indexing=\"ansi\", sym=False)\n","basis = FourierZernikeBasis(L=12, M=12, N=12)\n","r = np.linspace(0, 1, 100)"]},{"cell_type":"code","execution_count":6,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["zernike_radial, derivative order: 0\n","# With all the checks necessary and reverse mode AutoDiff capable\n","17.5 ms ± 126 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n","2.44 ms ± 42.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n","# With all the checks necessary but less efficient\n","4.51 ms ± 133 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"]}],"source":["dr = 0\n","print(f\"zernike_radial, derivative order: {dr}\")\n","\n","# print(\"# With no checks (full set of modes)\")\n","# %timeit _ = zernike_radial_no_check(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","# print(\"# With no duplicate modes (might have lacking modes)\")\n","# %timeit _ = zernike_radial_no_duplicate(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","# print(\"# With all the checks necessary but no reverse mode AutoDiff capable\")\n","# %timeit _ = zernike_radial(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","# %timeit _ = zernike_radial_separate(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary and reverse mode AutoDiff capable\")\n","# %timeit _ = zernike_radial_switch(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","# %timeit _ = zernike_radial_if_switch(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_switch_gpu(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","# %timeit _ = zernike_radial_jvp(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_jvp_gpu(r, basis.modes[:,0], basis.modes[:,1], dr, repeat=13).block_until_ready()\n","\n","print(\"# With all the checks necessary but less efficient\")\n","# %timeit _ = zernike_radial_rory(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_old_desc(r[:,np.newaxis], basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()"]},{"cell_type":"code","execution_count":4,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["zernike_radial, derivative order: 1\n","# With no checks (full set of modes)\n","1.52 ms ± 636 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With no duplicate modes (might have lacking modes)\n","1.59 ms ± 725 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary but no reverse mode AutoDiff capable\n","2.59 ms ± 646 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","2.55 ms ± 550 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary and reverse mode AutoDiff capable\n","3.59 ms ± 9.92 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","3.58 ms ± 7.23 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary but less efficient\n","6.31 ms ± 112 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","11 ms ± 626 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n"]}],"source":["dr = 1\n","print(f\"zernike_radial, derivative order: {dr}\")\n","\n","print(\"# With no checks (full set of modes)\")\n","%timeit _ = zernike_radial_no_check(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With no duplicate modes (might have lacking modes)\")\n","%timeit _ = zernike_radial_no_duplicate(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary but no reverse mode AutoDiff capable\")\n","%timeit _ = zernike_radial(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_separate(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary and reverse mode AutoDiff capable\")\n","%timeit _ = zernike_radial_switch(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_if_switch(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_switch_gpu(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_jvp(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary but less efficient\")\n","%timeit _ = zernike_radial_rory(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_old_desc(r[:,np.newaxis], basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()"]},{"cell_type":"code","execution_count":5,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["zernike_radial, derivative order: 2\n","# With no checks (full set of modes)\n","1.7 ms ± 705 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With no duplicate modes (might have lacking modes)\n","1.88 ms ± 695 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary but no reverse mode AutoDiff capable\n","2.7 ms ± 636 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","2.75 ms ± 640 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary and reverse mode AutoDiff capable\n","2.63 ms ± 19.6 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","2.64 ms ± 33 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary but less efficient\n","6.87 ms ± 50 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","14.9 ms ± 838 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n"]}],"source":["dr = 2\n","print(f\"zernike_radial, derivative order: {dr}\")\n","\n","print(\"# With no checks (full set of modes)\")\n","%timeit _ = zernike_radial_no_check(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With no duplicate modes (might have lacking modes)\")\n","%timeit _ = zernike_radial_no_duplicate(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary but no reverse mode AutoDiff capable\")\n","%timeit _ = zernike_radial(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_separate(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary and reverse mode AutoDiff capable\")\n","%timeit _ = zernike_radial_switch(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_if_switch(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_switch_gpu(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_jvp(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary but less efficient\")\n","%timeit _ = zernike_radial_rory(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_old_desc(r[:,np.newaxis], basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()"]},{"cell_type":"code","execution_count":6,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["zernike_radial, derivative order: 3\n","# With no checks (full set of modes)\n","1.93 ms ± 795 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With no duplicate modes (might have lacking modes)\n","2.07 ms ± 713 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary but no reverse mode AutoDiff capable\n","2.87 ms ± 736 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","3.95 ms ± 853 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary and reverse mode AutoDiff capable\n","2.69 ms ± 19.4 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","2.67 ms ± 13.2 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary but less efficient\n","7.51 ms ± 98.4 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","20.8 ms ± 2.42 ms per loop (mean ± std. dev. of 7 runs, 200 loops each)\n"]}],"source":["dr = 3\n","print(f\"zernike_radial, derivative order: {dr}\")\n","\n","print(\"# With no checks (full set of modes)\")\n","%timeit _ = zernike_radial_no_check(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With no duplicate modes (might have lacking modes)\")\n","%timeit _ = zernike_radial_no_duplicate(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary but no reverse mode AutoDiff capable\")\n","%timeit _ = zernike_radial(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_separate(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary and reverse mode AutoDiff capable\")\n","%timeit _ = zernike_radial_switch(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_if_switch(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_switch_gpu(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_jvp(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary but less efficient\")\n","%timeit _ = zernike_radial_rory(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_old_desc(r[:,np.newaxis], basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()"]},{"cell_type":"code","execution_count":7,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["zernike_radial, derivative order: 4\n","# With no checks (full set of modes)\n","2.4 ms ± 1.04 ms per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With no duplicate modes (might have lacking modes)\n","2.42 ms ± 820 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary but no reverse mode AutoDiff capable\n","4.91 ms ± 781 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","4.49 ms ± 1.1 ms per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary and reverse mode AutoDiff capable\n","5.1 ms ± 95.6 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","5.28 ms ± 171 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary but less efficient\n","8.52 ms ± 77.1 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","28.1 ms ± 3.5 ms per loop (mean ± std. dev. of 7 runs, 200 loops each)\n"]}],"source":["dr = 4\n","print(f\"zernike_radial, derivative order: {dr}\")\n","\n","print(\"# With no checks (full set of modes)\")\n","%timeit _ = zernike_radial_no_check(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With no duplicate modes (might have lacking modes)\")\n","%timeit _ = zernike_radial_no_duplicate(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary but no reverse mode AutoDiff capable\")\n","%timeit _ = zernike_radial(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_separate(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary and reverse mode AutoDiff capable\")\n","%timeit _ = zernike_radial_switch(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_if_switch(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_switch_gpu(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_jvp(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary but less efficient\")\n","%timeit _ = zernike_radial_rory(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_old_desc(r[:,np.newaxis], basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":[]}],"metadata":{"kernelspec":{"display_name":"desc-env","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.13"}},"nbformat":4,"nbformat_minor":2} +{"cells":[{"cell_type":"markdown","metadata":{},"source":["# Comparison of different available functions"]},{"cell_type":"code","execution_count":1,"metadata":{},"outputs":[],"source":["import sys\n","import os\n","\n","sys.path.insert(0, os.path.abspath(\".\"))\n","sys.path.append(os.path.abspath(\"../\"))"]},{"cell_type":"code","execution_count":2,"metadata":{},"outputs":[],"source":["# from zernipy import set_device\n","# set_device(\"gpu\")"]},{"cell_type":"code","execution_count":3,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["using JAX backend, jax version=0.4.14, jaxlib version=0.4.14, dtype=float64\n","Using device: CPU, with 20.36 GB available memory\n"]}],"source":["import numpy as np\n","from zernipy.zernike import *\n","from zernipy.basis import ZernikePolynomial, FourierZernikeBasis"]},{"cell_type":"code","execution_count":4,"metadata":{},"outputs":[],"source":["basis = ZernikePolynomial(L=50, M=50, spectral_indexing=\"ansi\", sym=False)\n","basis = FourierZernikeBasis(L=12, M=12, N=12)\n","r = np.linspace(0, 1, 100)"]},{"cell_type":"code","execution_count":6,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["zernike_radial, derivative order: 0\n","# With all the checks necessary and reverse mode AutoDiff capable\n","17.5 ms ± 126 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n","2.44 ms ± 42.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n","# With all the checks necessary but less efficient\n","4.51 ms ± 133 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"]}],"source":["dr = 0\n","print(f\"zernike_radial, derivative order: {dr}\")\n","\n","print(\"# With no checks (full set of modes)\")\n","%timeit _ = zernike_radial_no_check(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With no duplicate modes (might have lacking modes)\")\n","%timeit _ = zernike_radial_no_duplicate(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary but no reverse mode AutoDiff capable\")\n","%timeit _ = zernike_radial(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_separate(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary and reverse mode AutoDiff capable\")\n","%timeit _ = zernike_radial_switch(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_switch_gpu(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_jvp(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_jvp_gpu(r, basis.modes[:,0], basis.modes[:,1], dr, repeat=13).block_until_ready()\n","\n","print(\"# With all the checks necessary but less efficient\")\n","%timeit _ = zernike_radial_old_desc(r[:,np.newaxis], basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()"]},{"cell_type":"code","execution_count":4,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["zernike_radial, derivative order: 1\n","# With no checks (full set of modes)\n","1.52 ms ± 636 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With no duplicate modes (might have lacking modes)\n","1.59 ms ± 725 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary but no reverse mode AutoDiff capable\n","2.59 ms ± 646 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","2.55 ms ± 550 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary and reverse mode AutoDiff capable\n","3.59 ms ± 9.92 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","3.58 ms ± 7.23 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary but less efficient\n","6.31 ms ± 112 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","11 ms ± 626 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n"]}],"source":["dr = 1\n","print(f\"zernike_radial, derivative order: {dr}\")\n","\n","print(\"# With no checks (full set of modes)\")\n","%timeit _ = zernike_radial_no_check(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With no duplicate modes (might have lacking modes)\")\n","%timeit _ = zernike_radial_no_duplicate(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary but no reverse mode AutoDiff capable\")\n","%timeit _ = zernike_radial(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_separate(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary and reverse mode AutoDiff capable\")\n","%timeit _ = zernike_radial_switch(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_switch_gpu(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_jvp(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary but less efficient\")\n","%timeit _ = zernike_radial_old_desc(r[:,np.newaxis], basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()"]},{"cell_type":"code","execution_count":5,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["zernike_radial, derivative order: 2\n","# With no checks (full set of modes)\n","1.7 ms ± 705 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With no duplicate modes (might have lacking modes)\n","1.88 ms ± 695 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary but no reverse mode AutoDiff capable\n","2.7 ms ± 636 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","2.75 ms ± 640 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary and reverse mode AutoDiff capable\n","2.63 ms ± 19.6 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","2.64 ms ± 33 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary but less efficient\n","6.87 ms ± 50 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","14.9 ms ± 838 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n"]}],"source":["dr = 2\n","print(f\"zernike_radial, derivative order: {dr}\")\n","\n","print(\"# With no checks (full set of modes)\")\n","%timeit _ = zernike_radial_no_check(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With no duplicate modes (might have lacking modes)\")\n","%timeit _ = zernike_radial_no_duplicate(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary but no reverse mode AutoDiff capable\")\n","%timeit _ = zernike_radial(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_separate(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary and reverse mode AutoDiff capable\")\n","%timeit _ = zernike_radial_switch(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_switch_gpu(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_jvp(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary but less efficient\")\n","%timeit _ = zernike_radial_old_desc(r[:,np.newaxis], basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()"]},{"cell_type":"code","execution_count":6,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["zernike_radial, derivative order: 3\n","# With no checks (full set of modes)\n","1.93 ms ± 795 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With no duplicate modes (might have lacking modes)\n","2.07 ms ± 713 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary but no reverse mode AutoDiff capable\n","2.87 ms ± 736 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","3.95 ms ± 853 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary and reverse mode AutoDiff capable\n","2.69 ms ± 19.4 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","2.67 ms ± 13.2 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary but less efficient\n","7.51 ms ± 98.4 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","20.8 ms ± 2.42 ms per loop (mean ± std. dev. of 7 runs, 200 loops each)\n"]}],"source":["dr = 3\n","print(f\"zernike_radial, derivative order: {dr}\")\n","\n","print(\"# With no checks (full set of modes)\")\n","%timeit _ = zernike_radial_no_check(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With no duplicate modes (might have lacking modes)\")\n","%timeit _ = zernike_radial_no_duplicate(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary but no reverse mode AutoDiff capable\")\n","%timeit _ = zernike_radial(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_separate(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary and reverse mode AutoDiff capable\")\n","%timeit _ = zernike_radial_switch(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_switch_gpu(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_jvp(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary but less efficient\")\n","%timeit _ = zernike_radial_old_desc(r[:,np.newaxis], basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()"]},{"cell_type":"code","execution_count":7,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["zernike_radial, derivative order: 4\n","# With no checks (full set of modes)\n","2.4 ms ± 1.04 ms per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With no duplicate modes (might have lacking modes)\n","2.42 ms ± 820 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary but no reverse mode AutoDiff capable\n","4.91 ms ± 781 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","4.49 ms ± 1.1 ms per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary and reverse mode AutoDiff capable\n","5.1 ms ± 95.6 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","5.28 ms ± 171 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary but less efficient\n","8.52 ms ± 77.1 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","28.1 ms ± 3.5 ms per loop (mean ± std. dev. of 7 runs, 200 loops each)\n"]}],"source":["dr = 4\n","print(f\"zernike_radial, derivative order: {dr}\")\n","\n","print(\"# With no checks (full set of modes)\")\n","%timeit _ = zernike_radial_no_check(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With no duplicate modes (might have lacking modes)\")\n","%timeit _ = zernike_radial_no_duplicate(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary but no reverse mode AutoDiff capable\")\n","%timeit _ = zernike_radial(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_separate(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary and reverse mode AutoDiff capable\")\n","%timeit _ = zernike_radial_switch(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_switch_gpu(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_jvp(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary but less efficient\")\n","%timeit _ = zernike_radial_old_desc(r[:,np.newaxis], basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":[]}],"metadata":{"kernelspec":{"display_name":"desc-env","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.13"}},"nbformat":4,"nbformat_minor":2} diff --git a/zernipy/zernike.py b/zernipy/zernike.py index 7ca316b..9f69a32 100644 --- a/zernipy/zernike.py +++ b/zernipy/zernike.py @@ -1,7 +1,6 @@ """Functions for evaluating Zernike polynomials and their derivatives.""" import functools - from zernipy.backend import cond, custom_jvp, fori_loop, gammaln, jit, jnp, select, switch @@ -704,59 +703,6 @@ def body(alpha, args): return out -@custom_jvp -@jit -def zernike_radial_rory(r, l, m, dr=0): - """Radial part of zernike polynomials. - - Calculates Radial part of Zernike Polynomials using Jacobi recursion relation - by getting rid of the redundant calculations for appropriate modes. This version - is almost the same as zernike_radial_old function but way faster and more - accurate. First version of this function is zernike_radial_separate which has - many function for each derivative definition. User can refer that for clarity. - - There was even faster version of this code but that doesn't have checks - for duplicate modes - - # Find the index corresponding to the original array - # I changed arange function to get rid of 0 as index confusion - # so if index is full of 0s, there is no such mode - # (FAST BUT NEED A CHECK FOR DUPLICATE MODES) - index = jnp.where( - jnp.logical_and(m == alpha, n == N), - jnp.arange(1, m.size + 1), - 0, - ) - idx = jnp.sum(index) - # needed for proper index - idx -= 1 - result = (-1) ** N * r**alpha * P_n - out = out.at[:, idx].set(jnp.where(idx >= 0, result, out.at[:, idx].get())) - - Above part replaces the matrix update conducted by following code, - - _, _, _, out = fori_loop(0, m.size, update, (alpha, N, result, out)) - - Parameters - ---------- - r : ndarray, shape(N,) - radial coordinates to evaluate basis - l : ndarray of int, shape(K,) - radial mode number(s) - m : ndarray of int, shape(K,) - azimuthal mode number(s) - dr : int - order of derivative (Default = 0) - - Returns - ------- - out : ndarray, shape(N,K) - basis function(s) evaluated at specified points - - """ - return _zernike_radial_vectorized_rory(r, l, m, dr) - - @functools.partial(jit, static_argnums=3) def zernike_radial(r, l, m, dr=0): """Radial part of zernike polynomials. @@ -1030,76 +976,6 @@ def zernike_radial_old_desc(r, l, m, dr=0): return s * jnp.where((l - m) % 2 == 0, out, 0) -@custom_jvp -@jit -def zernike_radial_if_switch(r, l, m, dr=0): - """Radial part of zernike polynomials. - - Calculates Radial part of Zernike Polynomials using Jacobi recursion relation - by getting rid of the redundant calculations for appropriate modes. - https://en.wikipedia.org/wiki/Jacobi_polynomials#Recurrence_relations - - For the derivatives, the following formula is used with above recursion relation, - https://en.wikipedia.org/wiki/Jacobi_polynomials#Derivatives - - Used formulas are also in the zerike_eval.ipynb notebook in docs. - - This function can be made faster. However, JAX reverse mode AD causes problems. - In future, we may use vmap() instead of jnp.vectorize() to be able to set dr as - static argument, and not calculate every derivative even thoguh not asked. - - Parameters - ---------- - r : ndarray, shape(N,) or scalar - radial coordinates to evaluate basis - l : ndarray of int, shape(K,) or integer - radial mode number(s) - m : ndarray of int, shape(K,) or integer - azimuthal mode number(s) - dr : int - order of derivative (Default = 0) - - Returns - ------- - out : ndarray, shape(N,K) - basis function(s) evaluated at specified points - - """ - dr = jnp.asarray(dr).astype(int) - - def ZeroOne(args): - def Zero(args): - (r, l, m, dr) = args - return _zernike_radial_vectorized(r, l, m, dr) - - def One(args): - (r, l, m, dr) = args - return _zernike_radial_vectorized_d1(r, l, m, dr) - - return cond(dr == 0, Zero, One, (r, l, m, dr)) - - def Rest(args): - def Two(args): - (r, l, m, dr) = args - return _zernike_radial_vectorized_d2(r, l, m, dr) - - def ThreeFour(args): - def Three(args): - (r, l, m, dr) = args - return _zernike_radial_vectorized_d3(r, l, m, dr) - - def Four(args): - (r, l, m, dr) = args - return _zernike_radial_vectorized_d4(r, l, m, dr) - - return cond(dr == 3, Three, Four, (r, l, m, dr)) - - return cond(dr == 2, Two, ThreeFour, (r, l, m, dr)) - - # Switch doesn't work. Under JIT, only viable option seems this conditional - return cond(dr < 2, ZeroOne, Rest, (r, l, m, dr)) - - @custom_jvp @jit def zernike_radial_switch(r, l, m, dr=0): @@ -2021,169 +1897,6 @@ def body(alpha, out): return out -@functools.partial(jnp.vectorize, excluded=(1, 2, 3), signature="()->(k)") -def _zernike_radial_vectorized_rory(r, l, m, dr): - """Calculation of Radial part of Zernike polynomials.""" - - def update(i, args): - alpha, N, result, out = args - idx = jnp.where(jnp.logical_and(m[i] == alpha, n[i] == N), i, -1) - - def falseFun(args): - _, _, out = args - return out - - def trueFun(args): - idx, result, out = args - out = out.at[idx].set(result) - return out - - out = cond(idx >= 0, trueFun, falseFun, (idx, result, out)) - return (alpha, N, result, out) - - def body_inner(N, args): - alpha, out, P_past = args - P_n2 = P_past[0] # Jacobi at N-2 - P_n1 = P_past[1] # Jacobi at N-1 - P_n = jnp.zeros(MAXDR + 1) # Jacobi at N - - def find_inter_jacobi(dx, args): - N, alpha, P_n1, P_n2, P_n = args - P_n = P_n.at[dx].set( - jacobi_poly_single(r_jacobi, N - dx, alpha + dx, dx, P_n1[dx], P_n2[dx]) - ) - return (N, alpha, P_n1, P_n2, P_n) - - # Calculate Jacobi polynomial and derivatives for (alpha,N) - _, _, _, _, P_n = fori_loop( - 0, MAXDR + 1, find_inter_jacobi, (N, alpha, P_n1, P_n2, P_n) - ) - - # Calculate coefficients for derivatives. coef[0] will never be used. Jax - # doesn't have Gamma function directly, that's why we calculate Logarithm of - # Gamma function and then exponentiate it. - coef = jnp.exp( - gammaln(alpha + N + 1 + dxs) - dxs * jnp.log(2) - gammaln(alpha + N + 1) - ) - - # Since we cannot make dr static, we cannot use if statement. Instead define - # functions and execute only the function at dr th index - - # 0th Derivative of Zernike Radial - branch0 = lambda x: (-1) ** N * x**alpha * P_n[0] - # 1th Derivative of Zernike Radial - branch1 = lambda x: (-1) ** N * ( - alpha * x ** jnp.maximum(alpha - 1, 0) * P_n[0] - - coef[1] * 4 * x ** (alpha + 1) * P_n[1] - ) - # 2nd Derivative of Zernike Radial - branch2 = lambda x: (-1) ** N * ( - (alpha - 1) * alpha * x ** jnp.maximum(alpha - 2, 0) * P_n[0] - - coef[1] * 4 * (2 * alpha + 1) * x**alpha * P_n[1] - + coef[2] * 16 * x ** (alpha + 2) * P_n[2] - ) - # 3rd Derivative of Zernike Radial - branch3 = lambda x: (-1) ** N * ( - (alpha - 2) * (alpha - 1) * alpha * x ** jnp.maximum(alpha - 3, 0) * P_n[0] - - coef[1] * 12 * alpha**2 * x ** jnp.maximum(alpha - 1, 0) * P_n[1] - + coef[2] * 48 * (alpha + 1) * x ** (alpha + 1) * P_n[2] - - coef[3] * 64 * x ** (alpha + 3) * P_n[3] - ) - # 4th Derivative of Zernike Radial - branch4 = lambda x: (-1) ** N * ( - (alpha - 3) - * (alpha - 2) - * (alpha - 1) - * alpha - * x ** jnp.maximum(alpha - 4, 0) - * P_n[0] - - coef[1] - * 8 - * alpha - * (2 * alpha**2 - 3 * alpha + 1) - * x ** jnp.maximum(alpha - 2, 0) - * P_n[1] - + coef[2] * 48 * (2 * alpha**2 + 2 * alpha + 1) * x**alpha * P_n[2] - - coef[3] * 128 * (2 * alpha + 3) * x ** (alpha + 2) * P_n[3] - + coef[4] * 256 * x ** (alpha + 4) * P_n[4] - ) - # if dr is greater than 4, this will be executed - branch5 = lambda x: jnp.nan - branches = [branch0, branch1, branch2, branch3, branch4, branch5] - # Only calculate the function at dr th index with input r - result = switch(dr, branches, r) - # Check if the calculated values is in the given modes - _, _, _, out = fori_loop(0, m.size, update, (alpha, N, result, out)) - - # Shift past values if needed - # For derivative order dx, if N is smaller than 2+dx, then only the initial - # value calculated by find_init_jacobi function will be used. So, if you update - # P_n's, preceeding values will be wrong. - mask = N >= 2 + dxs - P_n2 = jnp.where(mask, P_n1, P_n2) - P_n1 = jnp.where(mask, P_n, P_n1) - # Form updated P_past matrix - P_past = P_past.at[0, :].set(P_n2) - P_past = P_past.at[1, :].set(P_n1) - - return (alpha, out, P_past) - - def body(alpha, out): - # find l values with m values equal to alpha - l_alpha = jnp.where(m == alpha, l, 0) - # find the maximum among them - L_max = jnp.max(l_alpha) - # Maximum possible value for n for loop bound - N_max = (L_max - alpha) // 2 - - # Find initial values of Jacobi Polynomial and derivatives - def find_init_jacobi(dx, args): - alpha, P_past = args - # Jacobi for n=0 - P_past = P_past.at[0, dx].set( - jacobi_poly_single(r_jacobi, 0, alpha + dx, beta=dx) - ) - # Jacobi for n=1 - P_past = P_past.at[1, dx].set( - jacobi_poly_single(r_jacobi, 1, alpha + dx, beta=dx) - ) - return (alpha, P_past) - - # First 2 Jacobi Polynomials (they don't need recursion) - # P_past stores last 2 Jacobi polynomials (and required derivatives) - # evaluated at given r points - P_past = jnp.zeros((2, MAXDR + 1)) - _, P_past = fori_loop(0, MAXDR + 1, find_init_jacobi, (alpha, P_past)) - - # Loop over every n value - _, out, _ = fori_loop( - 0, (N_max + 1).astype(int), body_inner, (alpha, out, P_past) - ) - return out - - # Make inputs 1D arrays in case they aren't - m = jnp.atleast_1d(m) - l = jnp.atleast_1d(l) - dr = jnp.asarray(dr).astype(int) - - # From the vectorization, the overall output will be (r.size, m.size) - out = jnp.zeros(m.size) - r_jacobi = 1 - 2 * r**2 - m = jnp.abs(m) - n = ((l - m) // 2).astype(int) - - # This part can be better implemented. Try to make dr as static argument - # jnp.vectorize doesn't allow it to be static - MAXDR = 4 - dxs = jnp.arange(0, MAXDR + 1) - - M_max = jnp.max(m) - # Loop over every different m value. There is another nested - # loop which will execute necessary n values. - out = fori_loop(0, (M_max + 1).astype(int), body, (out)) - return out - - @functools.partial(jnp.vectorize, excluded=(1, 2, 3), signature="()->(k)") def _zernike_radial_vectorized_gpu(r, l, m, dr): """Calculation of Radial part of Zernike polynomials.""" @@ -2763,19 +2476,6 @@ def _binom_body_fun(i, b_n): return b -@zernike_radial_rory.defjvp -def _zernike_radial_rory_jvp(x, xdot): - (r, l, m, dr) = x - (rdot, ldot, mdot, drdot) = xdot - f = zernike_radial_rory(r, l, m, dr) - df = zernike_radial_rory(r, l, m, dr + 1) - # in theory l, m, dr aren't differentiable (they're integers) - # but marking them as non-diff argnums seems to cause escaped tracer values. - # probably a more elegant fix, but just setting those derivatives to zero seems - # to work fine. - return f, (df.T * rdot).T + 0 * ldot + 0 * mdot + 0 * drdot - - @zernike_radial_jvp.defjvp def _zernike_radial_jvp(x, xdot): (r, l, m, dr) = x @@ -2828,19 +2528,6 @@ def _zernike_radial_switch_jvp(x, xdot): return f, (df.T * rdot).T + 0 * ldot + 0 * mdot + 0 * drdot -@zernike_radial_if_switch.defjvp -def _zernike_radial_if_switch_jvp(x, xdot): - (r, l, m, dr) = x - (rdot, ldot, mdot, drdot) = xdot - f = zernike_radial_if_switch(r, l, m, dr) - df = zernike_radial_if_switch(r, l, m, dr + 1) - # in theory l, m, dr aren't differentiable (they're integers) - # but marking them as non-diff argnums seems to cause escaped tracer values. - # probably a more elegant fix, but just setting those derivatives to zero seems - # to work fine. - return f, (df.T * rdot).T + 0 * ldot + 0 * mdot + 0 * drdot - - @_jacobi.defjvp def _jacobi_jvp(x, xdot): (n, alpha, beta, x, dx) = x