diff --git a/docs/getting_started/installation.md b/docs/getting_started/installation.md index 2b495e41..7c5264b4 100644 --- a/docs/getting_started/installation.md +++ b/docs/getting_started/installation.md @@ -37,6 +37,31 @@ a dependency by default. You need to have `blackjax` installed if you want to us pip install blackjax ``` +### Sampling with JAX support for GPU + +The `nuts_numpyro` sampler uses JAX as the backend and thus can support sampling on nvidia +GPU. The only thing you need to do to take advantage of this is to install JAX with CUDA +support before installing HSSM. Here's one example: + +```bash +python -m venv .venv # Create a virtual environment +source .venv/bin/activate # Activate the virtual environment + +pip install --upgrade pip + +# We need to limit the version of JAX for now due to some breaking +# changes introduced in JAX 0.4.16. +pip install --upgrade "jax[cuda11_pip]<0.4.16" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +pip install hssm +``` + +The example above shows how to install JAX with CUDA 11 support. Please refer to the +[JAX Installation](https://jax.readthedocs.io/en/latest/installation.html) page for more +details on installing JAX on different platforms with GPU or TPU support. + +Note that on Google Colab, JAX support for GPU is enabled by default if the Colab backend +has GPU enabled. You simply need only install HSSM. + ### Visualizing the model Model graphs are created with `model.graph()` through `graphviz`. In order to use it,