Speedup, relative to the CPU mp_nerf implementation, of different computational methods for single chains
This can be reproduced with notebooks/benchmark_single_chain_reconstruction.ipynb.
Leveraging JAX's automatic vectorization, the reconstruction was parallelized across chains and runs in 3.4 ms on GPU. The torch implementation has no parallel multi-chain implementation, so chains have to be computed serially; extrapolating its timing gives ~60 seconds, which makes the JAX version approximately 17,000x faster. This can be reproduced with notebooks/benchmark_multiple_chain_reconstruction.ipynb.
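As a rough illustration of how automatic vectorization parallelizes over chains, the sketch below builds a toy chain from internal coordinates and then batches it with `jax.vmap`. It is not the NerFax API: `place_atom`, `build_chain`, the fixed starting frame, and the random inputs are hypothetical names and values chosen for this example only.

```python
import jax
import jax.numpy as jnp


def place_atom(a, b, c, bond, angle, torsion):
    """Place atom d from the previous three atoms (one NeRF step).

    bond is |c-d|, angle is the b-c-d bond angle, torsion is the a-b-c-d dihedral.
    """
    bc = (c - b) / jnp.linalg.norm(c - b)
    n = jnp.cross(b - a, bc)
    n = n / jnp.linalg.norm(n)
    m = jnp.cross(n, bc)
    # Displacement expressed in the local orthonormal frame (bc, m, n).
    local = jnp.array([
        -bond * jnp.cos(angle),
        bond * jnp.sin(angle) * jnp.cos(torsion),
        bond * jnp.sin(angle) * jnp.sin(torsion),
    ])
    return c + local[0] * bc + local[1] * m + local[2] * n


def build_chain(bonds, angles, torsions):
    """Build one chain sequentially; inputs have shape (n_atoms,)."""
    # Hypothetical fixed starting frame of three non-collinear atoms.
    init = (
        jnp.array([0.0, 1.0, 0.0]),
        jnp.array([0.0, 0.0, 0.0]),
        jnp.array([1.0, 0.0, 0.0]),
    )

    def step(carry, params):
        a, b, c = carry
        d = place_atom(a, b, c, *params)
        return (b, c, d), d

    _, coords = jax.lax.scan(step, init, (bonds, angles, torsions))
    return coords  # shape (n_atoms, 3)


# vmap maps build_chain over a leading batch axis, so all chains are
# reconstructed in one vectorized call instead of a Python loop.
build_chains = jax.jit(jax.vmap(build_chain))

# Made-up internal coordinates for 4 chains of 10 atoms each.
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
bonds = jax.random.uniform(k1, (4, 10), minval=1.3, maxval=1.5)
angles = jax.random.uniform(k2, (4, 10), minval=1.9, maxval=2.1)
torsions = jax.random.uniform(k3, (4, 10), minval=-jnp.pi, maxval=jnp.pi)

coords = build_chains(bonds, angles, torsions)  # shape (4, 10, 3)
```

The per-chain function is written once for a single chain, and the batching is added afterwards by `jax.vmap`, which is why no separate multi-chain implementation is needed in this style.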
git clone https://github.com/PeptoneLtd/nerfax.git && pip install ./nerfax[optional]
Note: to run on GPU, a GPU version of JAX must be installed; please follow the JAX GPU compatibility instructions.
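A quick way to check which backend JAX picked up after installation is to query it directly. This is a minimal sketch using standard JAX calls; the variable names are illustrative, and the exact strings printed depend on the installed JAX version and hardware.

```python
import jax

# Name of the default backend; a CPU-only install reports "cpu".
backend = jax.default_backend()

# Devices visible to JAX on the default backend.
devices = jax.devices()

print("default backend:", backend)
print("devices:", devices)
```

If the backend is reported as "cpu" on a machine with a GPU, the CPU-only build of JAX is most likely the one installed.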
We also provide a Dockerfile which can be used to install NerFax. The Dockerfile includes the GPU version of JAX.