Skip to content

Commit

Permalink
added documentation for GPU support
Browse files Browse the repository at this point in the history
  • Loading branch information
digicosmos86 committed Dec 13, 2023
1 parent 67575a0 commit 6ae87ad
Showing 1 changed file with 25 additions and 0 deletions.
25 changes: 25 additions & 0 deletions docs/getting_started/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,31 @@ a dependency by default. You need to have `blackjax` installed if you want to us
pip install blackjax
```

### Sampling with JAX support for GPU

The `nuts_numpyro` sampler uses JAX as the backend and thus can support sampling on nvidia
GPU. The only thing you need to do to take advantage of this is to install JAX with CUDA
support before installing HSSM. Here's one example:

```bash
python -m venv .venv # Create a virtual environment
source .venv/bin/activate # Activate the virtual environment

pip install --upgrade pip

# We need to limit the version of JAX for now due to some breaking
# changes introduced in JAX 0.4.16.
pip install --upgrade "jax[cuda11_pip]<0.4.16" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install hssm
```

The example above shows how to install JAX with CUDA 11 support. Please refer to the
[JAX Installation](https://jax.readthedocs.io/en/latest/installation.html) page for more
details on installing JAX on different platforms with GPU or TPU support.

Note that on Google Colab, JAX support for GPU is enabled by default if the Colab backend
has GPU enabled. You simply need only install HSSM.

### Visualizing the model

Model graphs are created with `model.graph()` through `graphviz`. In order to use it,
Expand Down

0 comments on commit 6ae87ad

Please sign in to comment.