From c2860e6cc826c5e0450c6e70e11ca6e789fbcdb2 Mon Sep 17 00:00:00 2001 From: YigitElma Date: Sun, 25 Feb 2024 02:20:40 -0500 Subject: [PATCH] add jvp_gpu --- notebooks/benchmarks.ipynb | 2 +- zernipy/zernike.py | 188 ++++++++++++++++++++++++++++++++++++- 2 files changed, 188 insertions(+), 2 deletions(-) diff --git a/notebooks/benchmarks.ipynb b/notebooks/benchmarks.ipynb index 38f9908..802625e 100644 --- a/notebooks/benchmarks.ipynb +++ b/notebooks/benchmarks.ipynb @@ -1 +1 @@ -{"cells":[{"cell_type":"markdown","metadata":{},"source":["# Comparison of different available functions"]},{"cell_type":"code","execution_count":1,"metadata":{},"outputs":[],"source":["import sys\n","import os\n","\n","sys.path.insert(0, os.path.abspath(\".\"))\n","sys.path.append(os.path.abspath(\"../\"))"]},{"cell_type":"code","execution_count":2,"metadata":{},"outputs":[],"source":["# from zernipy import set_device\n","# set_device(\"gpu\")"]},{"cell_type":"code","execution_count":3,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["using JAX backend, jax version=0.4.14, jaxlib version=0.4.14, dtype=float64\n","Using device: CPU, with 21.52 GB available memory\n"]}],"source":["import numpy as np\n","from zernipy.zernike import *\n","from zernipy.basis import ZernikePolynomial, FourierZernikeBasis"]},{"cell_type":"code","execution_count":4,"metadata":{},"outputs":[],"source":["basis = ZernikePolynomial(L=50, M=50, spectral_indexing=\"ansi\", sym=False)\n","basis = FourierZernikeBasis(L=12, M=12, N=12)\n","r = np.linspace(0, 1, 100)"]},{"cell_type":"code","execution_count":7,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["zernike_radial, derivative order: 0\n","# With no checks (full set of modes)\n","867 µs ± 8.11 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n","# With no duplicate modes (might have lacking modes)\n","952 µs ± 6.44 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n","# With all the checks necessary but no reverse mode AutoDiff capable\n","1.76 ms ± 29.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n","1.74 ms ± 17.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n","# With all the checks necessary and reverse mode AutoDiff capable\n","2.16 ms ± 23.6 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n","2.13 ms ± 28.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n","17.8 ms ± 119 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n","2.1 ms ± 25.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n","# With all the checks necessary but less efficient\n","2.76 ms ± 48.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n","4.75 ms ± 168 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"]}],"source":["dr = 0\n","print(f\"zernike_radial, derivative order: {dr}\")\n","\n","print(\"# With no checks (full set of modes)\")\n","%timeit _ = zernike_radial_no_check(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With no duplicate modes (might have lacking modes)\")\n","%timeit _ = zernike_radial_no_duplicate(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary but no reverse mode AutoDiff capable\")\n","%timeit _ = zernike_radial(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_separate(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary and reverse mode AutoDiff capable\")\n","%timeit _ = zernike_radial_switch(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_if_switch(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_switch_gpu(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_jvp(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary but less efficient\")\n","%timeit _ = zernike_radial_rory(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_old_desc(r[:,np.newaxis], basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()"]},{"cell_type":"code","execution_count":4,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["zernike_radial, derivative order: 1\n","# With no checks (full set of modes)\n","1.52 ms ± 636 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With no duplicate modes (might have lacking modes)\n","1.59 ms ± 725 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary but no reverse mode AutoDiff capable\n","2.59 ms ± 646 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","2.55 ms ± 550 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary and reverse mode AutoDiff capable\n","3.59 ms ± 9.92 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","3.58 ms ± 7.23 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary but less efficient\n","6.31 ms ± 112 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","11 ms ± 626 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n"]}],"source":["dr = 1\n","print(f\"zernike_radial, derivative order: {dr}\")\n","\n","print(\"# With no checks (full set of modes)\")\n","%timeit _ = zernike_radial_no_check(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With no duplicate modes (might have lacking modes)\")\n","%timeit _ = zernike_radial_no_duplicate(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary but no reverse mode AutoDiff capable\")\n","%timeit _ = zernike_radial(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_separate(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary and reverse mode AutoDiff capable\")\n","%timeit _ = zernike_radial_switch(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_if_switch(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_switch_gpu(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_jvp(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary but less efficient\")\n","%timeit _ = zernike_radial_rory(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_old_desc(r[:,np.newaxis], basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()"]},{"cell_type":"code","execution_count":5,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["zernike_radial, derivative order: 2\n","# With no checks (full set of modes)\n","1.7 ms ± 705 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With no duplicate modes (might have lacking modes)\n","1.88 ms ± 695 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary but no reverse mode AutoDiff capable\n","2.7 ms ± 636 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","2.75 ms ± 640 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary and reverse mode AutoDiff capable\n","2.63 ms ± 19.6 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","2.64 ms ± 33 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary but less efficient\n","6.87 ms ± 50 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","14.9 ms ± 838 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n"]}],"source":["dr = 2\n","print(f\"zernike_radial, derivative order: {dr}\")\n","\n","print(\"# With no checks (full set of modes)\")\n","%timeit _ = zernike_radial_no_check(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With no duplicate modes (might have lacking modes)\")\n","%timeit _ = zernike_radial_no_duplicate(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary but no reverse mode AutoDiff capable\")\n","%timeit _ = zernike_radial(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_separate(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary and reverse mode AutoDiff capable\")\n","%timeit _ = zernike_radial_switch(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_if_switch(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_switch_gpu(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_jvp(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary but less efficient\")\n","%timeit _ = zernike_radial_rory(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_old_desc(r[:,np.newaxis], basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()"]},{"cell_type":"code","execution_count":6,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["zernike_radial, derivative order: 3\n","# With no checks (full set of modes)\n","1.93 ms ± 795 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With no duplicate modes (might have lacking modes)\n","2.07 ms ± 713 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary but no reverse mode AutoDiff capable\n","2.87 ms ± 736 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","3.95 ms ± 853 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary and reverse mode AutoDiff capable\n","2.69 ms ± 19.4 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","2.67 ms ± 13.2 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary but less efficient\n","7.51 ms ± 98.4 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","20.8 ms ± 2.42 ms per loop (mean ± std. dev. of 7 runs, 200 loops each)\n"]}],"source":["dr = 3\n","print(f\"zernike_radial, derivative order: {dr}\")\n","\n","print(\"# With no checks (full set of modes)\")\n","%timeit _ = zernike_radial_no_check(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With no duplicate modes (might have lacking modes)\")\n","%timeit _ = zernike_radial_no_duplicate(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary but no reverse mode AutoDiff capable\")\n","%timeit _ = zernike_radial(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_separate(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary and reverse mode AutoDiff capable\")\n","%timeit _ = zernike_radial_switch(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_if_switch(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_switch_gpu(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_jvp(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary but less efficient\")\n","%timeit _ = zernike_radial_rory(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_old_desc(r[:,np.newaxis], basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()"]},{"cell_type":"code","execution_count":7,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["zernike_radial, derivative order: 4\n","# With no checks (full set of modes)\n","2.4 ms ± 1.04 ms per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With no duplicate modes (might have lacking modes)\n","2.42 ms ± 820 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary but no reverse mode AutoDiff capable\n","4.91 ms ± 781 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","4.49 ms ± 1.1 ms per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary and reverse mode AutoDiff capable\n","5.1 ms ± 95.6 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","5.28 ms ± 171 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary but less efficient\n","8.52 ms ± 77.1 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","28.1 ms ± 3.5 ms per loop (mean ± std. dev. of 7 runs, 200 loops each)\n"]}],"source":["dr = 4\n","print(f\"zernike_radial, derivative order: {dr}\")\n","\n","print(\"# With no checks (full set of modes)\")\n","%timeit _ = zernike_radial_no_check(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With no duplicate modes (might have lacking modes)\")\n","%timeit _ = zernike_radial_no_duplicate(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary but no reverse mode AutoDiff capable\")\n","%timeit _ = zernike_radial(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_separate(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary and reverse mode AutoDiff capable\")\n","%timeit _ = zernike_radial_switch(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_if_switch(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_switch_gpu(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_jvp(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary but less efficient\")\n","%timeit _ = zernike_radial_rory(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_old_desc(r[:,np.newaxis], basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":[]}],"metadata":{"kernelspec":{"display_name":"desc-env","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.13"}},"nbformat":4,"nbformat_minor":2} +{"cells":[{"cell_type":"markdown","metadata":{},"source":["# Comparison of different available functions"]},{"cell_type":"code","execution_count":1,"metadata":{},"outputs":[],"source":["import sys\n","import os\n","\n","sys.path.insert(0, os.path.abspath(\".\"))\n","sys.path.append(os.path.abspath(\"../\"))"]},{"cell_type":"code","execution_count":2,"metadata":{},"outputs":[],"source":["# from zernipy import set_device\n","# set_device(\"gpu\")"]},{"cell_type":"code","execution_count":3,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["using JAX backend, jax version=0.4.14, jaxlib version=0.4.14, dtype=float64\n","Using device: CPU, with 21.13 GB available memory\n"]}],"source":["import numpy as np\n","from zernipy.zernike import *\n","from zernipy.basis import ZernikePolynomial, FourierZernikeBasis"]},{"cell_type":"code","execution_count":4,"metadata":{},"outputs":[],"source":["basis = ZernikePolynomial(L=50, M=50, spectral_indexing=\"ansi\", sym=False)\n","basis = FourierZernikeBasis(L=12, M=12, N=12)\n","r = np.linspace(0, 1, 100)"]},{"cell_type":"code","execution_count":6,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["zernike_radial, derivative order: 0"]},{"name":"stdout","output_type":"stream","text":["\n","# With no checks (full set of modes)\n","842 µs ± 9.18 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n","# With no duplicate modes (might have lacking modes)\n","849 µs ± 8.46 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n","# With all the checks necessary but no reverse mode AutoDiff capable\n","1.72 ms ± 33.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n","1.72 ms ± 30.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n","# With all the checks necessary and reverse mode AutoDiff capable\n","2.12 ms ± 20.3 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n","# With all the checks necessary but less efficient\n","2.58 ms ± 54.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n","4.46 ms ± 170 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"]}],"source":["dr = 0\n","print(f\"zernike_radial, derivative order: {dr}\")\n","\n","print(\"# With no checks (full set of modes)\")\n","%timeit _ = zernike_radial_no_check(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With no duplicate modes (might have lacking modes)\")\n","%timeit _ = zernike_radial_no_duplicate(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary but no reverse mode AutoDiff capable\")\n","%timeit _ = zernike_radial(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_separate(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary and reverse mode AutoDiff capable\")\n","# %timeit _ = zernike_radial_switch(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","# %timeit _ = zernike_radial_if_switch(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","# %timeit _ = zernike_radial_switch_gpu(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","# %timeit _ = zernike_radial_jvp(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_jvp_gpu(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary but less efficient\")\n","%timeit _ = zernike_radial_rory(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_old_desc(r[:,np.newaxis], basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()"]},{"cell_type":"code","execution_count":4,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["zernike_radial, derivative order: 1\n","# With no checks (full set of modes)\n","1.52 ms ± 636 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With no duplicate modes (might have lacking modes)\n","1.59 ms ± 725 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary but no reverse mode AutoDiff capable\n","2.59 ms ± 646 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","2.55 ms ± 550 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary and reverse mode AutoDiff capable\n","3.59 ms ± 9.92 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","3.58 ms ± 7.23 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary but less efficient\n","6.31 ms ± 112 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","11 ms ± 626 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n"]}],"source":["dr = 1\n","print(f\"zernike_radial, derivative order: {dr}\")\n","\n","print(\"# With no checks (full set of modes)\")\n","%timeit _ = zernike_radial_no_check(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With no duplicate modes (might have lacking modes)\")\n","%timeit _ = zernike_radial_no_duplicate(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary but no reverse mode AutoDiff capable\")\n","%timeit _ = zernike_radial(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_separate(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary and reverse mode AutoDiff capable\")\n","%timeit _ = zernike_radial_switch(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_if_switch(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_switch_gpu(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_jvp(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary but less efficient\")\n","%timeit _ = zernike_radial_rory(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_old_desc(r[:,np.newaxis], basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()"]},{"cell_type":"code","execution_count":5,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["zernike_radial, derivative order: 2\n","# With no checks (full set of modes)\n","1.7 ms ± 705 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With no duplicate modes (might have lacking modes)\n","1.88 ms ± 695 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary but no reverse mode AutoDiff capable\n","2.7 ms ± 636 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","2.75 ms ± 640 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary and reverse mode AutoDiff capable\n","2.63 ms ± 19.6 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","2.64 ms ± 33 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary but less efficient\n","6.87 ms ± 50 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","14.9 ms ± 838 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n"]}],"source":["dr = 2\n","print(f\"zernike_radial, derivative order: {dr}\")\n","\n","print(\"# With no checks (full set of modes)\")\n","%timeit _ = zernike_radial_no_check(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With no duplicate modes (might have lacking modes)\")\n","%timeit _ = zernike_radial_no_duplicate(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary but no reverse mode AutoDiff capable\")\n","%timeit _ = zernike_radial(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_separate(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary and reverse mode AutoDiff capable\")\n","%timeit _ = zernike_radial_switch(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_if_switch(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_switch_gpu(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_jvp(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary but less efficient\")\n","%timeit _ = zernike_radial_rory(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_old_desc(r[:,np.newaxis], basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()"]},{"cell_type":"code","execution_count":6,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["zernike_radial, derivative order: 3\n","# With no checks (full set of modes)\n","1.93 ms ± 795 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With no duplicate modes (might have lacking modes)\n","2.07 ms ± 713 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary but no reverse mode AutoDiff capable\n","2.87 ms ± 736 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","3.95 ms ± 853 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary and reverse mode AutoDiff capable\n","2.69 ms ± 19.4 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","2.67 ms ± 13.2 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary but less efficient\n","7.51 ms ± 98.4 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","20.8 ms ± 2.42 ms per loop (mean ± std. dev. of 7 runs, 200 loops each)\n"]}],"source":["dr = 3\n","print(f\"zernike_radial, derivative order: {dr}\")\n","\n","print(\"# With no checks (full set of modes)\")\n","%timeit _ = zernike_radial_no_check(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With no duplicate modes (might have lacking modes)\")\n","%timeit _ = zernike_radial_no_duplicate(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary but no reverse mode AutoDiff capable\")\n","%timeit _ = zernike_radial(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_separate(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary and reverse mode AutoDiff capable\")\n","%timeit _ = zernike_radial_switch(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_if_switch(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_switch_gpu(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_jvp(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary but less efficient\")\n","%timeit _ = zernike_radial_rory(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_old_desc(r[:,np.newaxis], basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()"]},{"cell_type":"code","execution_count":7,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["zernike_radial, derivative order: 4\n","# With no checks (full set of modes)\n","2.4 ms ± 1.04 ms per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With no duplicate modes (might have lacking modes)\n","2.42 ms ± 820 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary but no reverse mode AutoDiff capable\n","4.91 ms ± 781 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","4.49 ms ± 1.1 ms per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary and reverse mode AutoDiff capable\n","5.1 ms ± 95.6 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","5.28 ms ± 171 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","# With all the checks necessary but less efficient\n","8.52 ms ± 77.1 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)\n","28.1 ms ± 3.5 ms per loop (mean ± std. dev. of 7 runs, 200 loops each)\n"]}],"source":["dr = 4\n","print(f\"zernike_radial, derivative order: {dr}\")\n","\n","print(\"# With no checks (full set of modes)\")\n","%timeit _ = zernike_radial_no_check(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With no duplicate modes (might have lacking modes)\")\n","%timeit _ = zernike_radial_no_duplicate(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary but no reverse mode AutoDiff capable\")\n","%timeit _ = zernike_radial(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_separate(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary and reverse mode AutoDiff capable\")\n","%timeit _ = zernike_radial_switch(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_if_switch(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_switch_gpu(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_jvp(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","\n","print(\"# With all the checks necessary but less efficient\")\n","%timeit _ = zernike_radial_rory(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()\n","%timeit _ = zernike_radial_old_desc(r[:,np.newaxis], basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":[]}],"metadata":{"kernelspec":{"display_name":"desc-env","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.13"}},"nbformat":4,"nbformat_minor":2} diff --git a/zernipy/zernike.py b/zernipy/zernike.py index 48bf451..b5510a6 100644 --- a/zernipy/zernike.py +++ b/zernipy/zernike.py @@ -2,7 +2,7 @@ import functools -from zernipy.backend import cond, custom_jvp, fori_loop, gammaln, jit, jnp, switch +from zernipy.backend import cond, custom_jvp, fori_loop, gammaln, jax, jit, jnp, scan, switch def jacobi_poly_single(x, n, alpha, beta=0, P_n1=0, P_n2=0): @@ -1320,6 +1320,179 @@ def find_init_jacobi(dx, args): return out +@custom_jvp +@functools.partial(jit, static_argnums=3) +def zernike_radial_jvp_gpu(r, l, m, dr=0): + """Radial part of zernike polynomials. + + Calculates Radial part of Zernike Polynomials using Jacobi recursion relation + by getting rid of the redundant calculations for appropriate modes. This version + is almost the same as zernike_radial_old function but way faster and more + accurate. First version of this function is zernike_radial_separate which has + many function for each derivative definition. User can refer that for clarity. + + Parameters + ---------- + r : ndarray, shape(N,) + radial coordinates to evaluate basis + l : ndarray of int, shape(K,) + radial mode number(s) + m : ndarray of int, shape(K,) + azimuthal mode number(s) + dr : int + order of derivative (Default = 0) + + Returns + ------- + out : ndarray, shape(N,K) + basis function(s) evaluated at specified points + + """ + if dr > 4: + raise NotImplementedError( + "Analytic radial derivatives of Zernike polynomials for order>4 " + + "have not been implemented." + ) + + def update(args, x): + alpha, N, result, out = args + idx = jnp.where(jnp.logical_and(m[x] == alpha, n[x] == N), x, -1) + + def falseFun(args): + _, _, out = args + return out + + def trueFun(args): + idx, result, out = args + out = out.at[:, idx].set(result) + return out + + out = cond(idx >= 0, trueFun, falseFun, (idx, result, out)) + return (alpha, N, result, out), None + + def body_inner(N, args): + alpha, out, P_past = args + P_n2 = P_past[0] + P_n1 = P_past[1] + P_n = jnp.zeros((dr + 1, r.size)) + + def find_inter_jacobi(dx, args): + N, alpha, P_n1, P_n2, P_n = args + P_n = P_n.at[dx, :].set( + jacobi_poly_single(r_jacobi, N - dx, alpha + dx, dx, P_n1[dx], P_n2[dx]) + ) + return (N, alpha, P_n1, P_n2, P_n) + + # Calculate Jacobi polynomial and derivatives for (m,n) + _, _, _, _, P_n = fori_loop( + 0, dr + 1, find_inter_jacobi, (N, alpha, P_n1, P_n2, P_n) + ) + + coef = jnp.exp( + gammaln(alpha + N + 1 + dxs) - dxs * jnp.log(2) - gammaln(alpha + N + 1) + ) + # TODO: A version without if statements are possible? + if dr == 0: + result = (-1) ** N * r**alpha * P_n[0] + elif dr == 1: + result = (-1) ** N * ( + alpha * r ** jnp.maximum(alpha - 1, 0) * P_n[0] + - coef[1] * 4 * r ** (alpha + 1) * P_n[1] + ) + elif dr == 2: + result = (-1) ** N * ( + (alpha - 1) * alpha * r ** jnp.maximum(alpha - 2, 0) * P_n[0] + - coef[1] * 4 * (2 * alpha + 1) * r**alpha * P_n[1] + + coef[2] * 16 * r ** (alpha + 2) * P_n[2] + ) + elif dr == 3: + result = (-1) ** N * ( + (alpha - 2) + * (alpha - 1) + * alpha + * r ** jnp.maximum(alpha - 3, 0) + * P_n[0] + - coef[1] * 12 * alpha**2 * r ** jnp.maximum(alpha - 1, 0) * P_n[1] + + coef[2] * 48 * (alpha + 1) * r ** (alpha + 1) * P_n[2] + - coef[3] * 64 * r ** (alpha + 3) * P_n[3] + ) + elif dr == 4: + result = (-1) ** N * ( + (alpha - 3) + * (alpha - 2) + * (alpha - 1) + * alpha + * r ** jnp.maximum(alpha - 4, 0) + * P_n[0] + - coef[1] + * 8 + * alpha + * (2 * alpha**2 - 3 * alpha + 1) + * r ** jnp.maximum(alpha - 2, 0) + * P_n[1] + + coef[2] * 48 * (2 * alpha**2 + 2 * alpha + 1) * r**alpha * P_n[2] + - coef[3] * 128 * (2 * alpha + 3) * r ** (alpha + 2) * P_n[3] + + coef[4] * 256 * r ** (alpha + 4) * P_n[4] + ) + (_, _, _, out), _ = scan(update, (alpha, N, result, out), jnp.arange(m.size)) + + # Shift past values if needed + mask = N >= 2 + dxs + P_n2 = jnp.where(mask[:, None], P_n1, P_n2) + P_n1 = jnp.where(mask[:, None], P_n, P_n1) + P_past = P_past.at[0, :, :].set(P_n2) + P_past = P_past.at[1, :, :].set(P_n1) + + return (alpha, out, P_past) + + def body(alpha, out): + # find l values with m values equal to alpha + l_alpha = jnp.where(m == alpha, l, 0) + # find the maximum among them + L_max = jnp.max(l_alpha) + # Maximum possible value for n for loop bound + N_max = (L_max - alpha) // 2 + + def find_init_jacobi(dx, args): + alpha, P_past = args + P_past = P_past.at[0, dx, :].set( + jacobi_poly_single(r_jacobi, 0, alpha + dx, beta=dx) + ) + P_past = P_past.at[1, dx, :].set( + jacobi_poly_single(r_jacobi, 1, alpha + dx, beta=dx) + ) + return (alpha, P_past) + + # First 2 Jacobi Polynomials (they don't need recursion) + # P_past stores last 2 Jacobi polynomials (and required derivatives) + # evaluated at given r points + P_past = jnp.zeros((2, dr + 1, r.size)) + _, P_past = fori_loop(0, dr + 1, find_init_jacobi, (alpha, P_past)) + + # Loop over every n value + _, out, _ = fori_loop( + 0, (N_max + 1).astype(int), body_inner, (alpha, out, P_past) + ) + return out + + r = jnp.atleast_1d(r) + m = jnp.atleast_1d(m) + l = jnp.atleast_1d(l) + dr = int(dr) + + out = jnp.zeros((r.size, m.size)) + r_jacobi = 1 - 2 * r**2 + m = jnp.abs(m) + n = ((l - m) // 2).astype(int) + dxs = jnp.arange(0, dr + 1) + + M_max = jnp.max(m) + # Loop over every different m value. There is another nested + # loop which will execute necessary n values. + out = fori_loop(0, (M_max + 1).astype(int), body, (out)) + return out + + @custom_jvp @jit def zernike_radial_switch_gpu(r, l, m, dr=0): @@ -2626,6 +2799,19 @@ def _zernike_radial_jvp(x, xdot): return f, (df.T * rdot).T + 0 * ldot + 0 * mdot + 0 * drdot +@zernike_radial_jvp_gpu.defjvp +def _zernike_radial_jvp_gpu_jvp(x, xdot): + (r, l, m, dr) = x + (rdot, ldot, mdot, drdot) = xdot + f = zernike_radial_jvp_gpu(r, l, m, dr) + df = zernike_radial_jvp_gpu(r, l, m, dr + 1) + # in theory l, m, dr aren't differentiable (they're integers) + # but marking them as non-diff argnums seems to cause escaped tracer values. + # probably a more elegant fix, but just setting those derivatives to zero seems + # to work fine. + return f, (df.T * rdot).T + 0 * ldot + 0 * mdot + 0 * drdot + + @zernike_radial_switch_gpu.defjvp def _zernike_radial_switch_gpu_jvp(x, xdot): (r, l, m, dr) = x