Skip to content

Commit

Permalink
fix jax installation for CUDA12
Browse files Browse the repository at this point in the history
  • Loading branch information
priyakasimbeg committed Feb 29, 2024
1 parent 335acbc commit d01c101
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,7 @@ jax_cpu =
# JAX GPU
jax_gpu =
%(jax_core_deps)s
jax[cuda]==0.4.10
jaxlib==0.4.10+cuda11.cudnn86
jax[cuda12_local]==0.4.10

# PyTorch CPU
pytorch_cpu =
Expand Down

0 comments on commit d01c101

Please sign in to comment.